Repository: mindspore-ai/serving Branch: master Commit: df1fc19f8e4a Files: 473 Total size: 2.7 MB Directory structure: gitextract_vl65tc6t/ ├── .clang-format ├── .gitee/ │ └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── .gitmodules ├── .jenkins/ │ └── test/ │ └── config/ │ └── dependent_packages.yaml ├── CMakeLists.txt ├── LICENSE ├── NOTICE ├── OWNERS ├── README.md ├── README_CN.md ├── RELEASE.md ├── RELEASE_CN.md ├── Third_Party_Open_Source_Software_Notice ├── build.sh ├── cmake/ │ ├── check_requirements.cmake │ ├── dependency_ms.cmake │ ├── dependency_securec.cmake │ ├── dependency_utils.cmake │ ├── external_libs/ │ │ ├── absl.cmake │ │ ├── c-ares.cmake │ │ ├── eigen.cmake │ │ ├── glog.cmake │ │ ├── grpc.cmake │ │ ├── gtest.cmake │ │ ├── json.cmake │ │ ├── libevent.cmake │ │ ├── openssl.cmake │ │ ├── protobuf.cmake │ │ ├── pybind11.cmake │ │ ├── re2.cmake │ │ └── zlib.cmake │ ├── mind_expression.cmake │ ├── options.cmake │ ├── package.cmake │ ├── package_script.cmake │ └── utils.cmake ├── docs/ │ └── api/ │ └── api_python/ │ ├── client/ │ │ ├── mindspore_serving.client.Client.rst │ │ ├── mindspore_serving.client.SSLConfig.rst │ │ └── mindspore_serving.client.rst │ ├── mindspore_serving.client.rst │ ├── mindspore_serving.server.rst │ └── server/ │ ├── distributed/ │ │ ├── mindspore_serving.server.distributed.declare_servable.rst │ │ ├── mindspore_serving.server.distributed.rst │ │ ├── mindspore_serving.server.distributed.start_servable.rst │ │ └── mindspore_serving.server.distributed.startup_agents.rst │ ├── mindspore_serving.server.SSLConfig.rst │ ├── mindspore_serving.server.ServableStartConfig.rst │ ├── mindspore_serving.server.rst │ ├── mindspore_serving.server.start_grpc_server.rst │ ├── mindspore_serving.server.start_restful_server.rst │ ├── mindspore_serving.server.start_servables.rst │ ├── mindspore_serving.server.stop.rst │ └── register/ │ ├── mindspore_serving.server.register.AscendDeviceInfo.rst │ ├── mindspore_serving.server.register.CPUDeviceInfo.rst │ ├── mindspore_serving.server.register.Context.rst │ ├── mindspore_serving.server.register.GPUDeviceInfo.rst │ ├── mindspore_serving.server.register.Model.rst │ ├── mindspore_serving.server.register.add_stage.rst │ ├── mindspore_serving.server.register.declare_model.rst │ ├── mindspore_serving.server.register.register_method.rst │ └── mindspore_serving.server.register.rst ├── engine/ │ └── README.md ├── example/ │ ├── add_sub_pipeline/ │ │ ├── add_sub/ │ │ │ └── servable_config.py │ │ ├── export_model/ │ │ │ └── add_sub_model.py │ │ ├── serving_client.py │ │ └── serving_server.py │ ├── lenet/ │ │ ├── export_model/ │ │ │ ├── export_lenet.py │ │ │ └── lenet/ │ │ │ ├── __init__.py │ │ │ ├── export.py │ │ │ └── src/ │ │ │ └── lenet.py │ │ ├── lenet/ │ │ │ └── servable_config.py │ │ ├── serving_client.py │ │ └── serving_server.py │ ├── matmul_distributed/ │ │ ├── export_model/ │ │ │ ├── distributed_inference.py │ │ │ ├── export_model.sh │ │ │ ├── net.py │ │ │ └── rank_table_8pcs.json │ │ ├── matmul/ │ │ │ └── servable_config.py │ │ ├── rank_table_8pcs.json │ │ ├── serving_agent.py │ │ ├── serving_client.py │ │ └── serving_server.py │ ├── matmul_multi_subgraphs/ │ │ ├── export_model/ │ │ │ └── export_matmul.py │ │ ├── matmul/ │ │ │ └── servable_config.py │ │ ├── serving_client.py │ │ └── serving_server.py │ ├── resnet/ │ │ ├── export_model/ │ │ │ ├── export_resnet.py │ │ │ └── resnet/ │ │ │ ├── __init__.py │ │ │ ├── export.py │ │ │ └── src/ │ │ │ ├── config.py │ │ │ └── resnet.py │ │ ├── resnet50/ │ │ │ └── servable_config.py │ │ 
├── serving_client.py │ │ └── serving_server.py │ └── tensor_add/ │ ├── add/ │ │ └── servable_config.py │ ├── export_model/ │ │ └── add_model.py │ ├── serving_client.py │ ├── serving_client_with_check.py │ └── serving_server.py ├── mindspore_serving/ │ ├── CMakeLists.txt │ ├── __init__.py │ ├── ccsrc/ │ │ ├── common/ │ │ │ ├── buffer_tensor.cc │ │ │ ├── buffer_tensor.h │ │ │ ├── exit_handle.cc │ │ │ ├── exit_handle.h │ │ │ ├── float16.h │ │ │ ├── grpc_async_server.h │ │ │ ├── grpc_client.cc │ │ │ ├── grpc_client.h │ │ │ ├── grpc_server.cc │ │ │ ├── grpc_server.h │ │ │ ├── heart_beat.cc │ │ │ ├── heart_beat.h │ │ │ ├── instance.h │ │ │ ├── instance_data.h │ │ │ ├── log.cc │ │ │ ├── log.h │ │ │ ├── proto_tensor.cc │ │ │ ├── proto_tensor.h │ │ │ ├── servable.cc │ │ │ ├── servable.h │ │ │ ├── serving_common.h │ │ │ ├── shared_memory.cc │ │ │ ├── shared_memory.h │ │ │ ├── ssl_config.h │ │ │ ├── status.h │ │ │ ├── tensor.cc │ │ │ ├── tensor.h │ │ │ ├── tensor_base.cc │ │ │ ├── tensor_base.h │ │ │ ├── thread_pool.cc │ │ │ ├── thread_pool.h │ │ │ ├── utils.cc │ │ │ └── utils.h │ │ ├── master/ │ │ │ ├── dispacther.cc │ │ │ ├── dispacther.h │ │ │ ├── grpc/ │ │ │ │ ├── grpc_process.cc │ │ │ │ ├── grpc_process.h │ │ │ │ ├── grpc_server.cc │ │ │ │ ├── grpc_server.h │ │ │ │ └── master_server.h │ │ │ ├── master_context.cc │ │ │ ├── master_context.h │ │ │ ├── model_thread.cc │ │ │ ├── model_thread.h │ │ │ ├── notify_worker/ │ │ │ │ ├── base_notify.h │ │ │ │ ├── grpc_notify.cc │ │ │ │ └── grpc_notify.h │ │ │ ├── restful/ │ │ │ │ ├── http_handle.cc │ │ │ │ ├── http_handle.h │ │ │ │ ├── http_process.cc │ │ │ │ ├── http_process.h │ │ │ │ ├── restful_request.cc │ │ │ │ ├── restful_request.h │ │ │ │ ├── restful_server.cc │ │ │ │ └── restful_server.h │ │ │ ├── servable_endpoint.cc │ │ │ ├── servable_endpoint.h │ │ │ ├── server.cc │ │ │ ├── server.h │ │ │ ├── worker_context.cc │ │ │ └── worker_context.h │ │ ├── python/ │ │ │ ├── agent/ │ │ │ │ ├── agent_py.cc │ │ │ │ └── agent_py.h │ │ │ ├── master/ │ │ │ │ ├── master_py.cc │ │ │ │ └── master_py.h │ │ │ ├── serving_py.cc │ │ │ ├── tensor_py.cc │ │ │ ├── tensor_py.h │ │ │ └── worker/ │ │ │ ├── servable_py.cc │ │ │ ├── servable_py.h │ │ │ ├── worker_py.cc │ │ │ └── worker_py.h │ │ └── worker/ │ │ ├── context.cc │ │ ├── context.h │ │ ├── distributed_worker/ │ │ │ ├── agent_process/ │ │ │ │ ├── agent_process.cc │ │ │ │ └── agent_process.h │ │ │ ├── agent_startup.cc │ │ │ ├── agent_startup.h │ │ │ ├── common.h │ │ │ ├── distributed_model_loader.cc │ │ │ ├── distributed_model_loader.h │ │ │ ├── distributed_process/ │ │ │ │ ├── distributed_process.cc │ │ │ │ ├── distributed_process.h │ │ │ │ └── distributed_server.h │ │ │ ├── notify_agent/ │ │ │ │ ├── base_notify_agent.h │ │ │ │ ├── notify_agent.cc │ │ │ │ └── notify_agent.h │ │ │ ├── notify_distributed/ │ │ │ │ ├── notify_worker.cc │ │ │ │ └── notify_worker.h │ │ │ ├── worker_agent.cc │ │ │ └── worker_agent.h │ │ ├── extra_worker/ │ │ │ ├── remote_call_model.cc │ │ │ └── remote_call_model.h │ │ ├── grpc/ │ │ │ ├── worker_process.cc │ │ │ ├── worker_process.h │ │ │ ├── worker_server.cc │ │ │ └── worker_server.h │ │ ├── inference/ │ │ │ ├── inference.cc │ │ │ ├── inference.h │ │ │ ├── mindspore_model_wrap.cc │ │ │ └── mindspore_model_wrap.h │ │ ├── local_servable/ │ │ │ ├── local_model_loader.cc │ │ │ └── local_model_loader.h │ │ ├── model_loader_base.cc │ │ ├── model_loader_base.h │ │ ├── notfiy_master/ │ │ │ ├── base_notify.h │ │ │ ├── grpc_notify.cc │ │ │ └── grpc_notify.h │ │ ├── predict_thread.cc │ │ ├── 
predict_thread.h │ │ ├── register/ │ │ │ └── argmax.cc │ │ ├── servable_register.cc │ │ ├── servable_register.h │ │ ├── stage_function.cc │ │ ├── stage_function.h │ │ ├── task_queue.cc │ │ ├── task_queue.h │ │ ├── work_executor.cc │ │ ├── work_executor.h │ │ ├── worker.cc │ │ └── worker.h │ ├── client/ │ │ ├── __init__.py │ │ ├── cpp/ │ │ │ ├── client.cc │ │ │ └── client.h │ │ └── python/ │ │ ├── __init__.py │ │ └── client.py │ ├── log.py │ ├── proto/ │ │ ├── ms_agent.proto │ │ ├── ms_distributed.proto │ │ ├── ms_master.proto │ │ ├── ms_service.proto │ │ └── ms_worker.proto │ └── server/ │ ├── __init__.py │ ├── _servable_common.py │ ├── _servable_local.py │ ├── _server.py │ ├── common/ │ │ ├── __init__.py │ │ ├── check_type.py │ │ ├── decorator.py │ │ └── utils.py │ ├── distributed/ │ │ ├── __init__.py │ │ ├── _distributed.py │ │ ├── _servable_distributed.py │ │ └── start_distributed_worker.py │ ├── master/ │ │ ├── __init__.py │ │ ├── _master.py │ │ └── context.py │ ├── register/ │ │ ├── __init__.py │ │ ├── method.py │ │ ├── model.py │ │ ├── stage_function.py │ │ └── utils.py │ ├── start_extra_worker.py │ ├── start_worker.py │ └── worker/ │ ├── __init__.py │ ├── _worker.py │ ├── check_version.py │ ├── distributed/ │ │ ├── __init__.py │ │ ├── agent_startup.py │ │ ├── distributed_worker.py │ │ ├── register.py │ │ └── worker_agent.py │ ├── init_mindspore.py │ └── task.py ├── requirements_test.txt ├── scripts/ │ ├── check_clang_format.sh │ └── format_source_code.sh ├── setup.py ├── tests/ │ ├── CMakeLists.txt │ ├── st/ │ │ ├── add/ │ │ │ ├── __init__.py │ │ │ ├── add.sh │ │ │ └── test_serving.py │ │ ├── add_sub_pipeline/ │ │ │ ├── __init__.py │ │ │ ├── add_sub.sh │ │ │ └── test_serving.py │ │ ├── distributed_server_fault/ │ │ │ ├── __init__.py │ │ │ ├── common.sh │ │ │ ├── kill_15_agent.sh │ │ │ ├── kill_15_server.sh │ │ │ ├── kill_9_agent.sh │ │ │ ├── kill_9_server.sh │ │ │ └── test_distributed_fault.py │ │ ├── matmul_distributed/ │ │ │ ├── __init__.py │ │ │ ├── matmul_distribute.sh │ │ │ └── test_matmul_distribute.py │ │ ├── matmul_multi_subgraphs/ │ │ │ ├── __init__.py │ │ │ ├── matmul_multi_subgraphs.sh │ │ │ └── test_matmul_multi_subgraphs.py │ │ ├── resnet/ │ │ │ ├── __init__.py │ │ │ ├── resnet.sh │ │ │ └── test_resnet.py │ │ └── serving_fault/ │ │ ├── __init__.py │ │ ├── common.sh │ │ ├── kill_15_master.sh │ │ ├── kill_15_worker.sh │ │ ├── kill_9_master.sh │ │ ├── kill_9_worker.sh │ │ ├── restart.sh │ │ └── test_serving_fault.py │ └── ut/ │ ├── CMakeLists.txt │ ├── coverage/ │ │ ├── cov_config │ │ └── run_coverage.sh │ ├── cpp/ │ │ ├── CMakeLists.txt │ │ ├── common/ │ │ │ ├── common_test.cc │ │ │ ├── common_test.h │ │ │ ├── test_main.cc │ │ │ └── test_servable_common.h │ │ ├── runtest.sh │ │ └── tests/ │ │ ├── test_agent_config_acquire.cc │ │ ├── test_context.cc │ │ ├── test_distributed_inference.cc │ │ ├── test_init_config_on_start_up.cc │ │ ├── test_master_worker.cc │ │ ├── test_model_thread.cc │ │ ├── test_parse_restful.cc │ │ ├── test_shared_memory.cc │ │ ├── test_start_preprocess_postprocess.cc │ │ └── test_start_worker.cc │ ├── python/ │ │ ├── CMakeLists.txt │ │ ├── mindspore/ │ │ │ └── dataset/ │ │ │ └── __init__.py │ │ ├── runtest.sh │ │ ├── servable_config/ │ │ │ ├── add_servable_config.py │ │ │ └── generate_certs.sh │ │ └── tests/ │ │ ├── common.py │ │ ├── common_restful.py │ │ ├── test_distributed_worker.py │ │ ├── test_grpc_request.py │ │ ├── test_model_call.py │ │ ├── test_model_context.py │ │ ├── test_multi_model.py │ │ ├── test_python_parallel.py │ │ ├── 
test_register_method.py │ │ ├── test_restful_base64_data.py │ │ ├── test_restful_json_data.py │ │ ├── test_restful_request.py │ │ ├── test_server_client.py │ │ ├── test_serving_log.py │ │ ├── test_stage_function.py │ │ ├── test_start_servable_config.py │ │ └── test_start_sevables.py │ ├── runtest.sh │ └── stub/ │ ├── cxx_api/ │ │ ├── cell.cc │ │ ├── context.cc │ │ ├── factory.h │ │ ├── graph/ │ │ │ ├── ascend/ │ │ │ │ ├── ascend_graph_impl.cc │ │ │ │ └── ascend_graph_impl.h │ │ │ ├── graph.cc │ │ │ ├── graph_data.cc │ │ │ ├── graph_data.h │ │ │ └── graph_impl.h │ │ ├── model/ │ │ │ ├── model.cc │ │ │ ├── model_impl.cc │ │ │ ├── model_impl.h │ │ │ └── ms/ │ │ │ ├── ms_model.cc │ │ │ └── ms_model.h │ │ ├── serialization.cc │ │ ├── status.cc │ │ └── types.cc │ ├── graph_impl_stub.cc │ ├── graph_impl_stub.h │ ├── include/ │ │ ├── api/ │ │ │ ├── allocator.h │ │ │ ├── callback/ │ │ │ │ ├── callback.h │ │ │ │ ├── ckpt_saver.h │ │ │ │ ├── loss_monitor.h │ │ │ │ ├── lr_scheduler.h │ │ │ │ ├── time_monitor.h │ │ │ │ └── train_accuracy.h │ │ │ ├── cell.h │ │ │ ├── cfg.h │ │ │ ├── context.h │ │ │ ├── data_type.h │ │ │ ├── delegate.h │ │ │ ├── dual_abi_helper.h │ │ │ ├── format.h │ │ │ ├── graph.h │ │ │ ├── kernel.h │ │ │ ├── metrics/ │ │ │ │ ├── accuracy.h │ │ │ │ └── metrics.h │ │ │ ├── model.h │ │ │ ├── model_parallel_runner.h │ │ │ ├── ops/ │ │ │ │ └── ops.h │ │ │ ├── serialization.h │ │ │ ├── status.h │ │ │ ├── types.h │ │ │ └── visible.h │ │ ├── mindapi/ │ │ │ └── base/ │ │ │ ├── format.h │ │ │ ├── type_id.h │ │ │ └── types.h │ │ └── utils/ │ │ ├── log_adapter.cc │ │ ├── log_adapter.h │ │ ├── log_adapter_common.cc │ │ ├── overload.h │ │ ├── utils.h │ │ └── visible.h │ ├── stub_inference.cc │ ├── stub_postprocess.cc │ └── stub_preprocess.cc └── third_party/ ├── patch/ │ ├── c-ares/ │ │ └── CVE-2021-3672.patch │ ├── glog/ │ │ └── glog.patch001 │ ├── grpc/ │ │ └── grpc.patch001 │ ├── libevent/ │ │ └── libevent.patch001 │ ├── openssl/ │ │ ├── CVE-2021-3711.patch │ │ ├── CVE-2021-3712.patch │ │ ├── CVE-2021-4160.patch │ │ ├── CVE-2022-0778.patch │ │ ├── CVE-2022-1292.patch │ │ ├── CVE-2022-2068.patch │ │ ├── CVE-2022-2097.patch │ │ ├── CVE-2022-4304.patch │ │ ├── CVE-2022-4450.patch │ │ ├── CVE-2023-0215.patch │ │ ├── CVE-2023-0286.patch │ │ ├── CVE-2023-0464.patch │ │ ├── CVE-2023-0465.patch │ │ ├── CVE-2023-0466.patch │ │ ├── CVE-2023-2650.patch │ │ ├── CVE-2023-3446.patch │ │ └── CVE-2023-4807.patch │ ├── protobuf/ │ │ ├── CVE-2021-22570.patch │ │ └── CVE-2022-1941.patch │ ├── pybind11/ │ │ └── pybind11.patch001 │ └── zlib/ │ ├── CVE-2018-25032.patch │ └── CVE-2022-37434.patch └── securec/ ├── CMakeLists.txt ├── include/ │ ├── securec.h │ └── securectype.h └── src/ ├── CMakeLists.txt ├── fscanf_s.c ├── fwscanf_s.c ├── gets_s.c ├── input.inl ├── memcpy_s.c ├── memmove_s.c ├── memset_s.c ├── output.inl ├── scanf_s.c ├── secinput.h ├── securecutil.c ├── securecutil.h ├── secureinput_a.c ├── secureinput_w.c ├── secureprintoutput.h ├── secureprintoutput_a.c ├── secureprintoutput_w.c ├── snprintf_s.c ├── sprintf_s.c ├── sscanf_s.c ├── strcat_s.c ├── strcpy_s.c ├── strncat_s.c ├── strncpy_s.c ├── strtok_s.c ├── swprintf_s.c ├── swscanf_s.c ├── vfscanf_s.c ├── vfwscanf_s.c ├── vscanf_s.c ├── vsnprintf_s.c ├── vsprintf_s.c ├── vsscanf_s.c ├── vswprintf_s.c ├── vswscanf_s.c ├── vwscanf_s.c ├── wcscat_s.c ├── wcscpy_s.c ├── wcsncat_s.c ├── wcsncpy_s.c ├── wcstok_s.c ├── wmemcpy_s.c ├── wmemmove_s.c └── wscanf_s.c ================================================ FILE CONTENTS 
================================================ ================================================ FILE: .clang-format ================================================ --- Language: Cpp # BasedOnStyle: Google AccessModifierOffset: -1 AlignAfterOpenBracket: Align AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: false AlignEscapedNewlines: Left AlignOperands: true AlignTrailingComments: true AllowAllParametersOfDeclarationOnNextLine: true AllowShortBlocksOnASingleLine: false AllowShortCaseLabelsOnASingleLine: false AllowShortFunctionsOnASingleLine: All AllowShortIfStatementsOnASingleLine: true AllowShortLoopsOnASingleLine: true AlwaysBreakAfterDefinitionReturnType: None AlwaysBreakAfterReturnType: None AlwaysBreakBeforeMultilineStrings: true AlwaysBreakTemplateDeclarations: Yes BinPackArguments: true BinPackParameters: true BraceWrapping: AfterClass: false AfterControlStatement: false AfterEnum: false AfterFunction: false AfterNamespace: false AfterObjCDeclaration: false AfterStruct: false AfterUnion: false AfterExternBlock: false BeforeCatch: false BeforeElse: false IndentBraces: false SplitEmptyFunction: true SplitEmptyRecord: true SplitEmptyNamespace: true BreakBeforeBinaryOperators: None BreakBeforeBraces: Attach BreakBeforeInheritanceComma: false BreakInheritanceList: BeforeColon BreakBeforeTernaryOperators: true BreakConstructorInitializersBeforeComma: false BreakConstructorInitializers: BeforeColon BreakAfterJavaFieldAnnotations: false BreakStringLiterals: true ColumnLimit: 120 CommentPragmas: '^ IWYU pragma:' CompactNamespaces: false ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 2 Cpp11BracedListStyle: true DerivePointerAlignment: false DisableFormat: false ExperimentalAutoDetectBinPacking: false FixNamespaceComments: true ForEachMacros: # - foreach - Q_FOREACH - BOOST_FOREACH IncludeBlocks: Preserve IncludeCategories: - Regex: '^' Priority: 2 - Regex: '^<.*\.h>' Priority: 1 - Regex: '^<.*' Priority: 2 - Regex: '.*' Priority: 3 IncludeIsMainRegex: '([-_](test|unittest))?$' IndentCaseLabels: true IndentPPDirectives: None IndentWidth: 2 IndentWrappedFunctionNames: false JavaScriptQuotes: Leave JavaScriptWrapImports: true KeepEmptyLinesAtTheStartOfBlocks: false MacroBlockBegin: '' MacroBlockEnd: '' MaxEmptyLinesToKeep: 1 NamespaceIndentation: None ObjCBinPackProtocolList: Never ObjCBlockIndentWidth: 2 ObjCSpaceAfterProperty: false ObjCSpaceBeforeProtocolList: true PenaltyBreakAssignment: 2 PenaltyBreakBeforeFirstCallParameter: 1 PenaltyBreakComment: 300 PenaltyBreakFirstLessLess: 120 PenaltyBreakString: 1000 PenaltyBreakTemplateDeclaration: 10 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 PointerAlignment: Right RawStringFormats: - Language: Cpp Delimiters: - cc - CC - cpp - Cpp - CPP - 'c++' - 'C++' CanonicalDelimiter: '' BasedOnStyle: google - Language: TextProto Delimiters: - pb - PB - proto - PROTO EnclosingFunctions: - EqualsProto - EquivToProto - PARSE_PARTIAL_TEXT_PROTO - PARSE_TEST_PROTO - PARSE_TEXT_PROTO - ParseTextOrDie - ParseTextProtoOrDie CanonicalDelimiter: '' BasedOnStyle: google ReflowComments: true SortUsingDeclarations: true SpaceAfterCStyleCast: false SpaceAfterTemplateKeyword: true SpaceBeforeAssignmentOperators: true SpaceBeforeCpp11BracedList: false SpaceBeforeCtorInitializerColon: true SpaceBeforeInheritanceColon: true SpaceBeforeParens: ControlStatements SpaceBeforeRangeBasedForLoopColon: true SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 2 
SpacesInAngles: false SpacesInContainerLiterals: true SpacesInCStyleCastParentheses: false SpacesInParentheses: false SpacesInSquareBrackets: false Standard: Auto StatementMacros: - Q_UNUSED - QT_REQUIRE_VERSION TabWidth: 2 UseTab: Never SortIncludes: false ... ================================================ FILE: .gitee/PULL_REQUEST_TEMPLATE.md ================================================ **What type of PR is this?** > Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: > > /kind bug > /kind task > /kind feature **What does this PR do / why do we need it**: **Which issue(s) this PR fixes**: Fixes # **Special notes for your reviewers**: ================================================ FILE: .gitignore ================================================ # MindSpore Serving build/ mindspore_serving/lib output *.ir .coverage* htmlcov/ cov_output/ # Cmake files CMakeFiles/ cmake_install.cmake CMakeCache.txt Makefile cmake-build-debug # Prerequisites *.d # Compiled Object files *.slo *.lo *.o *.obj # Precompiled Headers *.gch *.pch # Compiled Dynamic libraries *.so *.dylib *.dll *.so.* # Fortran module files *.mod *.smod # Compiled Static libraries *.lai *.la *.a *.lib # Executables *.exe *.out *.app # Protocol buffers *_pb2.py *.pb.h *.pb.cc *.pb *_grpc.py # Editor .vscode .idea/ # Cquery .cquery_cached_index/ compile_commands.json # Ctags and cscope tags TAGS CTAGS GTAGS GRTAGS GSYMS GPATH cscope.* # Python files *__pycache__* .pytest_cache # Mac files *.DS_Store # Test results test_temp_summary_event_file/ *.dot *.dat *.svg *.perf *.info *.ckpt *.shp *.pkl .clangd mindspore_serving/version.py mindspore_serving/default_config.py mindspore_serving/.commit_id tests/ut/python/tests/ca.crt tests/ut/python/tests/ca.key tests/ut/python/tests/ca.srl tests/ut/python/tests/server.crt tests/ut/python/tests/server.csr tests/ut/python/tests/server.key tests/ut/python/tests/client.crt tests/ut/python/tests/client.csr tests/ut/python/tests/client.key tests/ut/python/tests/serving_logs/ tests/ut/python/tests/unix_socket_files/ tests/ut/python/tests/serving_python_ut_servables/ ================================================ FILE: .gitmodules ================================================ [submodule "third_party/mindspore"] path = third_party/mindspore url = https://gitee.com/mindspore/mindspore.git ================================================ FILE: .jenkins/test/config/dependent_packages.yaml ================================================ mindspore: 'mindspore/mindspore/version/202310/20231010/master_20231010144855_e5008bcfa07e3e6f3fa50f3ba0ac90175504dfd7/' ================================================ FILE: CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.14.1) project(MindSpore_Serving) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0) message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0") endif() include(${CMAKE_SOURCE_DIR}/cmake/options.cmake) # set compile options include(${CMAKE_SOURCE_DIR}/cmake/check_requirements.cmake) # check require party, like OpenSSL set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 \ -D_FORTIFY_SOURCE=2") if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows") add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) endif() if(ENABLE_PYTHON) add_compile_definitions(ENABLE_PYTHON) endif() set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} 
-O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer \ -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 \ -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -fPIC") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(PYBIND11_CPP_STANDARD -std=c++17) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPTION_CXX_FLAGS}") # compile third party: grpc, libevent, gtest, onnx include(${CMAKE_SOURCE_DIR}/cmake/mind_expression.cmake) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/securec/include) # find python3 packages include(${CMAKE_SOURCE_DIR}/cmake/dependency_utils.cmake) find_package(Python3 3.7 COMPONENTS Interpreter Development) if(Python3_FOUND) set(PYTHON_INCLUDE_DIRS "${Python3_INCLUDE_DIRS}") set(PYTHON_LIBRARIES "${Python3_LIBRARIES}") if(WIN32) if(Python3_DIR) message("Python3_DIR set already: " ${Python3_DIR}) else() string(LENGTH ${PYTHON_LIBRARIES} PYTHON_LIBRARIES_LEN) string(LENGTH "libpythonxx.a" Python3_NAME_LEN) math(EXPR Python3_DIR_LEN ${PYTHON_LIBRARIES_LEN}-${Python3_NAME_LEN}) string(SUBSTRING ${Python3_LIBRARIES} 0 ${Python3_DIR_LEN} Python3_DIR) message("Python3_DIR: " ${Python3_DIR}) endif() link_directories(${Python3_DIR}) endif() else() find_python_package(py_inc py_lib) set(PYTHON_INCLUDE_DIRS "${py_inc}") set(PYTHON_LIBRARIES "${py_lib}") endif() message("PYTHON_INCLUDE_DIRS = ${PYTHON_INCLUDE_DIRS}") message("PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}") include_directories(${PYTHON_INCLUDE_DIRS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") find_package(Threads REQUIRED) if(ENABLE_TESTCASES) add_subdirectory(tests) endif() add_subdirectory(mindspore_serving) include(cmake/package.cmake) ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================
FILE: NOTICE
================================================
MindSpore Serving
Copyright 2020 Huawei Technologies Co., Ltd

================================================
FILE: OWNERS
================================================
approvers:
- zhaizhiqiang
- zhangxuetong
- hangangqiang

================================================
FILE: README.md
================================================
# MindSpore Serving

[查看中文](./README_CN.md)

- [MindSpore Serving](#mindspore-serving)
    - [Overview](#overview)
    - [Installation](#installation)
        - [Installing Serving](#installing-serving)
        - [Configuring Environment Variables](#configuring-environment-variables)
    - [Quick Start](#quick-start)
    - [Documents](#documents)
        - [Developer Guide](#developer-guide)
    - [Community](#community)
        - [Governance](#governance)
        - [Communication](#communication)
    - [Contributions](#contributions)
    - [Release Notes](#release-notes)
    - [License](#license)

## Overview

MindSpore Serving is a lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment. After completing model training on MindSpore, you can export the MindSpore model and use MindSpore Serving to create an inference service for the model.

MindSpore Serving architecture:

*(figure: MindSpore Serving architecture diagram)*

MindSpore Serving includes two parts: `Client` and `Server`. On a `Client` node, you can deliver inference service commands through the gRPC or RESTful API. The `Server` consists of a `Main` node and one or more `Worker` nodes. The `Main` node manages all `Worker` nodes and their model information, accepts user requests from `Client`s, and distributes the requests to `Worker` nodes. A `Servable` is deployed on each worker node; it represents a single model or a combination of multiple models and can provide different services through various methods.

On the server side, when [MindSpore](https://www.mindspore.cn/) is used as the inference backend, MindSpore Serving supports the Ascend 910 and Nvidia GPU environments. When [MindSpore Lite](https://www.mindspore.cn/lite) is used as the inference backend, MindSpore Serving supports the Ascend 310/310P, Nvidia GPU and CPU environments. The `Client` does not depend on specific hardware platforms.

MindSpore Serving provides the following functions:

- gRPC and RESTful APIs on clients
- Pre-processing and post-processing of assembled models
- Batching: multiple instance requests are split and combined to meet the `batch size` requirement of the model
- Simple Python APIs on clients
- Multi-model combination: the multi-model combination and single-model scenarios use the same set of interfaces
- Distributed model inference

## Installation

For details about how to install and configure MindSpore Serving, see the [MindSpore Serving installation page](https://www.mindspore.cn/serving/docs/en/master/serving_install.html).

## Quick Start

[MindSpore-based Inference Service Deployment](https://www.mindspore.cn/serving/docs/en/master/serving_example.html) is used to demonstrate how to use MindSpore Serving.
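As a quick taste before the tutorials, here is a minimal, illustrative client sketch assembled from the examples later in this repository; it assumes a serving server is already running with an `add` servable exposing an `add_common` method on `127.0.0.1:5500` (servable and method names are taken from the bundled `tensor_add` example):

```python
import numpy as np
from mindspore_serving.client import Client

# Connect to a running serving server; "add" is the servable name and
# "add_common" the method name, as in the tensor_add example.
client = Client("127.0.0.1:5500", "add", "add_common")

# Each instance is a dict mapping input names to numpy arrays; several
# instances can be sent in one request and are batched by the server.
instances = [{"x1": np.ones((2, 2), np.float32), "x2": np.ones((2, 2), np.float32)}]

result = client.infer(instances)  # one output dict per instance
print(result)
```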
## Documents

### Developer Guide

- [gRPC-based MindSpore Serving Access](https://www.mindspore.cn/serving/docs/en/master/serving_grpc.html)
- [RESTful-based MindSpore Serving Access](https://www.mindspore.cn/serving/docs/en/master/serving_restful.html)
- [Services Provided Through Model Configuration](https://www.mindspore.cn/serving/docs/en/master/serving_model.html)
- [Services Composed of Multiple Models](https://www.mindspore.cn/serving/docs/en/master/serving_model.html#services-composed-of-multiple-models)
- [MindSpore Serving-based Distributed Inference Service Deployment](https://www.mindspore.cn/serving/docs/en/master/serving_distributed_example.html)

For more details about the installation guide, tutorials, and APIs, see [MindSpore Python API](https://www.mindspore.cn/serving/docs/en/master/server.html).

## Community

### Governance

[MindSpore Open Governance](https://gitee.com/mindspore/community/blob/master/governance.md)

### Communication

- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/zt-dgk65rli-3ex4xvS4wHX7UDmsQmfu8w) developer communication platform

## Contributions

Contributions to MindSpore Serving are welcome.

## Release Notes

[RELEASE](RELEASE.md)

## License

[Apache License 2.0](LICENSE)

================================================
FILE: README_CN.md
================================================
# MindSpore Serving

[View English](./README.md)

- [MindSpore Serving](#mindspore-serving)
    - [概述](#概述)
    - [安装](#安装)
        - [安装Serving](#安装serving)
        - [配置环境变量](#配置环境变量)
    - [快速入门](#快速入门)
    - [文档](#文档)
        - [开发者教程](#开发者教程)
    - [社区](#社区)
        - [治理](#治理)
        - [交流](#交流)
    - [贡献](#贡献)
    - [版本说明](#版本说明)
    - [许可证](#许可证)

## 概述

MindSpore Serving是一个轻量级、高性能的服务模块,旨在帮助MindSpore开发者在生产环境中高效部署在线推理服务。当用户使用MindSpore完成模型训练后,导出MindSpore模型,即可使用MindSpore Serving创建该模型的推理服务。

MindSpore Serving架构:

*(figure: MindSpore Serving架构图)*

MindSpore Serving分为客户端、服务器两个部分。在客户端中,用户通过gRPC或RESTful接口向服务器下发推理服务命令。服务器包括主(`Main`)节点和一个或多个工作(`Worker`)节点,主节点管理所有的工作节点及其部署的模型信息,接受客户端的用户请求,并将请求分发给工作节点。每个工作节点部署了一个可服务对象,即`Servable`,这里的`Servable`可以是单个模型,也可以是多个模型的组合,一个`Servable`可以围绕相同的模型通过多种方法来提供不同的服务。

对于服务端,当以[MindSpore](https://www.mindspore.cn/)作为推理后端时,MindSpore Serving当前支持Ascend 910和Nvidia GPU环境。当以[MindSpore Lite](https://www.mindspore.cn/lite)作为推理后端时,MindSpore Serving当前支持Ascend 310/310P、Nvidia GPU和CPU。客户端不依赖特定硬件平台。

MindSpore Serving提供以下功能:

- 支持客户端gRPC和RESTful接口。
- 支持组装模型的前处理和后处理。
- 支持batch功能,多实例请求会被拆分组合以满足模型`batch size`的需要。
- 提供客户端Python简易接口。
- 支持多模型组合,多模型组合和单模型场景使用相同的一套接口。
- 支持分布式模型推理功能。

## 安装

MindSpore Serving安装和配置可以参考[MindSpore Serving安装页面](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_install.html)。

## 快速入门

以一个简单的[Add网络示例](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_example.html),演示MindSpore Serving如何使用。

## 文档

### 开发者教程

- [基于gRPC接口访问MindSpore Serving服务](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_grpc.html)
- [基于RESTful接口访问MindSpore Serving服务](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_restful.html)
- [配置模型提供服务](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_model.html)
- [配置多模型组合的服务](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_model.html#id9)
- [基于MindSpore Serving部署分布式推理服务](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_distributed_example.html)

有关安装指南、教程和API的更多详细信息,请参阅[用户文档](https://www.mindspore.cn/serving/docs/zh-CN/master/server.html)。

## 社区

### 治理

查看MindSpore如何进行[开放治理](https://gitee.com/mindspore/community/blob/master/governance.md)。

### 交流
- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/zt-dgk65rli-3ex4xvS4wHX7UDmsQmfu8w) 开发者交流平台。

## 贡献

欢迎参与贡献。

## 版本说明

版本说明请参阅[RELEASE](RELEASE.md)。

## 许可证

[Apache License 2.0](LICENSE)

================================================
FILE: RELEASE.md
================================================
# MindSpore Serving Release Notes

[查看中文](./RELEASE_CN.md)

## MindSpore Serving 2.0.2 Release Notes

### Major Features and Improvements

- Released based on MindSpore 2.2.0.
- Fixed third-party OpenSSL vulnerabilities: CVE-2023-3446 and CVE-2023-4807.

### Contributors

Thanks go to these wonderful people: qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 2.0.0 Release Notes

### Major Features and Improvements

- Released based on MindSpore 2.0.0rc1.
- Fixed third-party OpenSSL vulnerabilities: CVE-2022-4304, CVE-2022-4450, CVE-2023-0286, CVE-2023-0464, CVE-2023-0465 and CVE-2023-0466.

### Contributors

Thanks go to these wonderful people: qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.8.0 Release Notes

### Major Features and Improvements

- [STABLE] When deploying a large-scale model with parallel pipeline, Serving supports parallel pipeline processing of multiple inference instances.

### Contributors

Thanks go to these wonderful people: qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.7.0 Release Notes

### Major Features and Improvements

- [DEMO] Ascend 310P can be used as the inference device; for more detail, see [MindSpore Serving backend](https://www.mindspore.cn/serving/docs/en/master/serving_install.html#installation).
- [DEMO] Models of the MindIR format are supported when MindSpore Lite is used as the MindSpore Serving inference backend; for more detail, see [MindSpore Serving backend](https://www.mindspore.cn/serving/docs/en/master/serving_install.html#installation).

#### Deprecations

##### Python API

- `AclOptions` and `GpuOptions` are removed from version 1.7.0; use `AscendDeviceInfo` and `GPUDeviceInfo` instead.
- `register.declare_servable` and `register.call_servable` are removed from version 1.7.0; use `register.declare_model` and `register.add_stage` instead.
- `register.call_preprocess`, `register.call_preprocess_pipeline`, `register.call_postprocess` and `register.call_postprocess_pipeline` are removed from version 1.7.0; use `register.add_stage` instead.

### Contributors

Thanks go to these wonderful people: qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.6.0 Release Notes

### Major Features and Improvements

- [STABLE] We can use the existing interfaces (`declare_model` and `add_stage`) that define single-model services to define multi-model composite services.
- [STABLE] When the number of occupied devices is fixed, additional worker processes (using parameter `num_parallel_workers`) are supported to accelerate Python functions such as preprocessing and postprocessing, improving device utilization.
- [STABLE] The interface `Model.call` is a stable feature, and can be used to define complex model invocation processes in the Serving server, such as looping and conditional branching.
- [STABLE] The new interfaces `Context`, `CPUDeviceInfo`, `GPUDeviceInfo` and `AscendDeviceInfo` are provided to set user-defined device information (see the sketch after this list). The original interfaces `GpuOptions` and `AclOptions` are deprecated.
- [BETA] We support MindSpore Lite as the MindSpore Serving inference backend; for more detail, see [MindSpore Serving backend](https://www.mindspore.cn/serving/docs/en/master/serving_install.html#installation).
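The notes above do not show the new device-information interfaces in action, so here is a tentative sketch; treat the `Context` construction and the `append_device_info` call as assumptions inferred from the documented class names (check the `register.Context` API documentation for the released signatures):

```python
from mindspore_serving.server import register

# Assumed usage: build a context and attach device-specific options to it.
# `append_device_info` is an assumption, not a confirmed signature.
context = register.Context()
context.append_device_info(register.GPUDeviceInfo(precision_mode="fp16"))

# declare_model takes the new `context` parameter in place of the
# deprecated `options` parameter.
model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR",
                               context=context)
```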
### API Change

#### New features

##### Python API

###### Multi-model composite services

We can use the existing interfaces (`declare_model` and `add_stage`) that define single-model services to define multi-model composite services. For more detail, see [Services Composed of Multiple Models](https://www.mindspore.cn/serving/docs/en/master/serving_model.html#services-composed-of-multiple-models).

```python
from mindspore_serving.server import register

add_model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
sub_model = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)


@register.register_method(output_names=["y"])
def add_sub_only_model(x1, x2, x3):  # x1+x2-x3
    y = register.add_stage(add_model, x1, x2, outputs_count=1)
    y = register.add_stage(sub_model, y, x3, outputs_count=1)
    return y
```

###### Additional worker processes are supported to accelerate Python functions (preprocessing and postprocessing)

Parameter `num_parallel_workers` in class `ServableStartConfig` is a stable feature. It can be used to configure the total number of workers. The number of workers occupying devices is determined by the length of parameter `device_ids`. Additional worker processes use the worker processes that occupy devices for model inference. For more detail, see [Multi-process Concurrency](https://www.mindspore.cn/serving/docs/en/master/serving_model.html#multi-process-concurrency).

```python
class ServableStartConfig:
    def __init__(self, servable_directory, servable_name, device_ids, version_number=0, device_type=None,
                 num_parallel_workers=0, dec_key=None, dec_mode='AES-GCM')
```

Start the serving server that contains the `resnet50` servable. The `resnet50` servable has four worker processes (`num_parallel_workers`), one of which occupies the device (`device_ids`).

```python
import os
import sys
from mindspore_serving import server


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    # Total of 4 workers: one worker occupies device 0, and the model inference tasks
    # of the other workers are forwarded to the worker that occupies the device.
    config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="resnet50", device_ids=0,
                                        num_parallel_workers=4)
    server.start_servables(config)
    server.start_grpc_server("127.0.0.1:5500")
    server.start_restful_server("127.0.0.1:1500")


if __name__ == "__main__":
    start()
```

###### Model.call interface can be used to define complex model invocation processes

The interface `Model.call` is a stable feature, and can be used to define complex model invocation processes in the Serving server, such as looping and conditional branching.

```python
from mindspore_serving.server import register
import numpy as np
from .tokenizer import create_tokenizer, padding, END_TOKEN

bert_model = register.declare_model(model_file="bert_poetry.mindir", model_format="MindIR")


def calc_new_token(probas):
    ...
    return new_token_id


tokenizer = create_tokenizer()


def generate_sentence(input_sentence):
    input_token_ids = tokenizer.encode(input_sentence)
    target_ids = []
    MAX_LEN = 64
    while len(input_token_ids) + len(target_ids) < MAX_LEN:
        input_ids = padding(np.array(input_token_ids + target_ids), length=128)
        pad_mask = (input_ids != 0).astype(np.float32)
        probas = bert_model.call(input_ids, pad_mask)  # call bert model to generate token id of new word
        new_token_id = calc_new_token(probas[len(input_token_ids)])
        target_ids.append(new_token_id)
        if new_token_id == END_TOKEN:
            break
    output_sentence = tokenizer.decode(input_token_ids + target_ids)
    return output_sentence


@register.register_method(output_names=["output_sentence"])
def predict(input_sentence):
    output_sentence = register.add_stage(generate_sentence, input_sentence, outputs_count=1)
    return output_sentence
```

#### Deprecations

##### Python API

- The parameter `options` in `register.declare_model` is deprecated from version 1.6.0 and will be removed in a future version; use parameter `context` instead.
- `AclOptions` and `GpuOptions` are deprecated from version 1.6.0 and will be removed in a future version; use `AscendDeviceInfo` and `GPUDeviceInfo` instead.

### Contributors

Thanks go to these wonderful people: qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.5.0 Release Notes

### Major Features and Improvements

- [STABLE] To support multi-model orchestration (to be released in version 1.6), a set of APIs (`declare_model` and `add_stage`) is added. The new APIs will be used in single-model and multi-model scenarios. The old APIs (`register.declare_servable`, `call_servable`, `call_preprocess`, `call_postprocess`) used in single-model scenarios are deprecated.
- [BETA] When the number of occupied devices is fixed, additional worker processes are supported to accelerate Python functions such as preprocessing and postprocessing, improving device utilization.
- [BETA] `Model.call` interface is added to support invoking models in Python functions.

### API Change

#### API Incompatible Change

##### Python API

###### New set of APIs for single-model and multi-model scenarios

To support multiple models (will be officially released in version 1.6), a set of APIs (`declare_model` and `add_stage`) is added. The single-model and multi-model scenarios will use the same set of APIs. New APIs are recommended in single-model scenarios. Old APIs (`declare_servable`, `call_servable`, `call_preprocess`, `call_postprocess`) are deprecated.
**1.4:**

```python
from mindspore_serving.server import register

register.declare_servable(servable_file="resnet.mindir", model_format="MindIR")


def resnet_preprocess(image):
    ....


def resnet_postprocess(scores):
    ....


@register.register_method(output_names=["label"])
def predict(image):
    x = register.call_preprocess(resnet_preprocess, image)
    x = register.call_servable(x)
    x = register.call_postprocess(resnet_postprocess, x)
    return x
```

**1.5:**

```python
from mindspore_serving.server import register

resnet_model = register.declare_model(model_file="resnet.mindir", model_format="MindIR")


def resnet_preprocess(image):
    ....


def resnet_postprocess(scores):
    ....


@register.register_method(output_names=["label"])
def predict(image):
    x = register.add_stage(resnet_preprocess, image, outputs_count=1)
    x = register.add_stage(resnet_model, x, outputs_count=1)
    x = register.add_stage(resnet_postprocess, x, outputs_count=1)
    return x
```
#### New features

##### Python API

###### Additional worker processes are supported to accelerate Python functions (preprocessing and postprocessing)

Parameter `num_parallel_workers` is added to class `ServableStartConfig` to configure the total number of workers. The number of workers occupying devices is determined by the length of parameter `device_ids`. Additional worker processes use the worker processes that occupy devices for model inference.

```python
class ServableStartConfig:
    def __init__(self, servable_directory, servable_name, device_ids, version_number=0, device_type=None,
                 num_parallel_workers=0, dec_key=None, dec_mode='AES-GCM')
```

Start the serving server that contains the `resnet50` servable. The `resnet50` servable has four worker processes (`num_parallel_workers`), one of which occupies the device (`device_ids`).

```python
import os
import sys
from mindspore_serving import server


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    # Total of 4 workers: one worker occupies device 0, and the model inference tasks
    # of the other workers are forwarded to the worker that occupies the device.
    config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="resnet50", device_ids=0,
                                        num_parallel_workers=4)
    server.start_servables(config)
    server.start_grpc_server("127.0.0.1:5500")
    server.start_restful_server("127.0.0.1:1500")


if __name__ == "__main__":
    start()
```

###### Model.call interface is added to support invoking models in Python functions

```python
from mindspore_serving.server import register

add_model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR")


def add_func(x1, x2, x3, x4):
    instances = []
    instances.append((x1, x2))
    instances.append((x3, x4))
    output_instances = add_model.call(instances)  # for multi instances
    y1 = output_instances[0][0]  # instance 0 output 0
    y2 = output_instances[1][0]  # instance 1 output 0
    y = add_model.call(y1, y2)  # for single instance
    return y


@register.register_method(output_names=["y"])
def predict(x1, x2, x3, x4):
    y = register.add_stage(add_func, x1, x2, x3, x4, outputs_count=1)
    return y
```

#### Deprecations

##### Python API

- `register.declare_servable`, `call_servable`, `call_preprocess`, `call_postprocess`, `call_preprocess_pipeline` and `call_postprocess_pipeline` are now deprecated in favor of `register.declare_model` and `add_stage`, as shown above. Deprecated interfaces will be deleted in the future.
- Beta interfaces `PipelineServable` and `register_pipeline` introduced in version 1.3 will be deleted and replaced with `Model.call`.

### Contributors

Thanks go to these wonderful people: chenweifeng, qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.3.0 Release Notes

### Major Features and Improvements

- [STABLE] Enhances and simplifies the deployment and startup of single-chip models. Multiple models can be loaded by a single script. Each model can have multiple copies on multiple chips. Requests can be split and distributed to these copies for concurrent execution.
- [STABLE] The `master`+`worker` interface of the Serving server is changed to the `server` interface.
- [STABLE] The client and server support Unix Domain Socket-based gRPC communication.
- [STABLE] gRPC and RESTful interfaces support TLS/SSL security authentication.
- [STABLE] Encrypted MindIR models are supported.
- [BETA] Incremental inference models consisting of multiple static graphs are supported, including single-card models and distributed models.
### API Change

#### API Incompatible Change

##### Python API

###### Enhances and simplifies the deployment and startup of single-chip models

Multiple models can be loaded by a single script. Each model can have multiple copies on multiple chips. Requests can be split and distributed to these copies for concurrent execution. The interface `worker.start_servable_in_master`, which can start only a single servable, is changed to the interface `server.start_servables`, which can start multiple servables, and each servable can correspond to multiple copies. In addition, the related interface `server.ServableStartConfig` is added.
**1.2.x:**

```python
import os
import sys
from mindspore_serving import master
from mindspore_serving import worker


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    # deploy model add on device 0
    worker.start_servable_in_master(servable_dir, "add", device_id=0)
    master.start_grpc_server("127.0.0.1", 5500)
    master.start_restful_server("127.0.0.1", 1500)


if __name__ == "__main__":
    start()
```

**1.3.0:**

```python
import os
import sys
from mindspore_serving import server


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    # deploy model add on devices 0 and 1
    add_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="add",
                                            device_ids=(0, 1))
    # deploy model resnet50 on devices 2 and 3
    resnet50_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="resnet50",
                                                 device_ids=(2, 3))
    server.start_servables(servable_configs=(add_config, resnet50_config))
    server.start_grpc_server(address="127.0.0.1:5500")
    server.start_restful_server(address="127.0.0.1:1500")


if __name__ == "__main__":
    start()
```
###### `mindspore_serving.worker.register` is updated to `mindspore_serving.server.register`
**1.2.x:**

```python
from mindspore_serving.worker import register
```

**1.3.0:**

```python
from mindspore_serving.server import register
```
###### The gRPC and RESTful startup interfaces are updated

The namespace is changed from `master` to `server`, and the input parameters `ip` and `port` are changed to `address` only.
**1.2.x:**

```python
from mindspore_serving import master

master.start_grpc_server("127.0.0.1", 5500)
master.start_restful_server("127.0.0.1", 1500)
master.stop()
```

**1.3.0:**

```python
from mindspore_serving import server

server.start_grpc_server("127.0.0.1:5500")
server.start_restful_server("127.0.0.1:1500")
server.stop()
```
###### The name of the distributed interface function is simplified, and the namespace is changed from `worker` to `server`

In `servable_config.py` of the distributed model:
**1.2.x:**

```python
from mindspore_serving.worker import distributed

distributed.declare_distributed_servable(
    rank_size=8, stage_size=1, with_batch_dim=False)
```

**1.3.0:**

```python
from mindspore_serving.server import distributed

distributed.declare_servable(
    rank_size=8, stage_size=1, with_batch_dim=False)
```
In the startup script of the distributed model:
**1.2.x:**

```python
import os
import sys
from mindspore_serving import master
from mindspore_serving.worker import distributed


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    distributed.start_distributed_servable_in_master(
        servable_dir, "matmul",
        rank_table_json_file="rank_table_8pcs.json",
        version_number=1,
        worker_ip="127.0.0.1", worker_port=6200)
    master.start_grpc_server("127.0.0.1", 5500)
    master.start_restful_server("127.0.0.1", 1500)


if __name__ == "__main__":
    start()
```

**1.3.0:**

```python
import os
import sys
from mindspore_serving import server
from mindspore_serving.server import distributed


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    distributed.start_servable(
        servable_dir, "matmul",
        rank_table_json_file="rank_table_8pcs.json",
        version_number=1,
        distributed_address="127.0.0.1:6200")
    server.start_grpc_server("127.0.0.1:5500")
    server.start_restful_server("127.0.0.1:1500")


if __name__ == "__main__":
    start()
```
In the agent startup script of the distributed model:
1.2.x:

```python
from mindspore_serving.worker import distributed

def start_agents():
    """Start all the worker agents in current machine"""
    model_files = []
    group_configs = []
    for i in range(8):
        model_files.append(f"model/device{i}/matmul.mindir")
        group_configs.append(f"model/device{i}/group_config.pb")
    distributed.startup_worker_agents(
        worker_ip="127.0.0.1", worker_port=6200,
        model_files=model_files, group_config_files=group_configs)

if __name__ == '__main__':
    start_agents()
```

1.3.0:

```python
from mindspore_serving.server import distributed

def start_agents():
    """Start all the agents in current machine"""
    model_files = []
    group_configs = []
    for i in range(8):
        model_files.append(f"model/device{i}/matmul.mindir")
        group_configs.append(f"model/device{i}/group_config.pb")
    distributed.startup_agents(
        distributed_address="127.0.0.1:6200",
        model_files=model_files, group_config_files=group_configs)

if __name__ == '__main__':
    start_agents()
```
###### The input parameters `ip` and `port` of the gRPC client are changed to `address`

In addition to the `{ip}:{port}` address format, the Unix Domain Socket format `unix:{unix_domain_file_path}` is supported.
1.2.x:

```python
import numpy as np
from mindspore_serving.client import Client

def run_add_cast():
    """invoke servable add method add_cast"""
    client = Client("localhost", 5500, "add", "add_cast")
    instances = []
    x1 = np.ones((2, 2), np.int32)
    x2 = np.ones((2, 2), np.int32)
    instances.append({"x1": x1, "x2": x2})
    result = client.infer(instances)
    print(result)

if __name__ == '__main__':
    run_add_cast()
```

1.3.0:

```python
import numpy as np
from mindspore_serving.client import Client

def run_add_cast():
    """invoke servable add method add_cast"""
    client = Client("127.0.0.1:5500", "add", "add_cast")
    instances = []
    x1 = np.ones((2, 2), np.int32)
    x2 = np.ones((2, 2), np.int32)
    instances.append({"x1": x1, "x2": x2})
    result = client.infer(instances)
    print(result)

if __name__ == '__main__':
    run_add_cast()
```
#### New features

##### Python API

###### Support Unix Domain Socket

The Serving server:

```python
import os
import sys
from mindspore_serving import server

def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="resnet50",
                                                 device_ids=(0, 1))
    server.start_servables(servable_configs=servable_config)
    server.start_grpc_server(address="unix:/tmp/serving_resnet50_test_temp_file")

if __name__ == "__main__":
    start()
```

The Serving client:

```python
import os
from mindspore_serving.client import Client

def run_classify_top1():
    client = Client("unix:/tmp/serving_resnet50_test_temp_file", "resnet50", "classify_top1")
    instances = []
    for path, _, file_list in os.walk("./test_image/"):
        for file_name in file_list:
            image_file = os.path.join(path, file_name)
            print(image_file)
            with open(image_file, "rb") as fp:
                instances.append({"image": fp.read()})
    result = client.infer(instances)
    print(result)

if __name__ == '__main__':
    run_classify_top1()
```

###### Support SSL/TLS

The Serving server:

```python
import os
import sys
from mindspore_serving import server

def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="add",
                                                 device_ids=(0, 1))
    server.start_servables(servable_configs=servable_config)
    ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca=None,
                                  verify_client=False)
    server.start_grpc_server(address="127.0.0.1:5500", ssl_config=ssl_config)
    server.start_restful_server(address="127.0.0.1:1500", ssl_config=ssl_config)

if __name__ == "__main__":
    start()
```

The gRPC Serving client:

```python
import numpy as np
from mindspore_serving.client import Client
from mindspore_serving.client import SSLConfig

def run_add_common():
    """invoke Servable add method add_common"""
    ssl_config = SSLConfig(custom_ca="ca.crt")
    client = Client("localhost:5500", "add", "add_common", ssl_config=ssl_config)
    instances = []
    # instance 1
    x1 = np.asarray([[1, 1], [1, 1]]).astype(np.float32)
    x2 = np.asarray([[1, 1], [1, 1]]).astype(np.float32)
    instances.append({"x1": x1, "x2": x2})
    result = client.infer(instances)
    print(result)

if __name__ == '__main__':
    run_add_common()
```

The RESTful client:

```shell
>>> curl -X POST -d '{"instances":{"x1":[[1.0, 1.0], [1.0, 1.0]], "x2":[[1.0, 1.0], [1.0, 1.0]]}}' --insecure https://127.0.0.1:1500/model/add:add_common
{"instances":[{"y":[[2.0,2.0],[2.0,2.0]]}]}
```
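The same RESTful request can also be issued from Python. The sketch below uses the third-party `requests` package (not part of MindSpore Serving) and is only an illustrative equivalent of the curl command above; `verify=False` mirrors curl's `--insecure` and skips certificate verification, so it is suitable only for testing:

```python
# Illustrative Python equivalent of the curl command above, using the
# third-party `requests` package. verify=False skips certificate
# verification (like curl's --insecure) and is for testing only.
import requests

payload = {"instances": {"x1": [[1.0, 1.0], [1.0, 1.0]], "x2": [[1.0, 1.0], [1.0, 1.0]]}}
response = requests.post("https://127.0.0.1:1500/model/add:add_common", json=payload, verify=False)
print(response.json())  # expected: {"instances": [{"y": [[2.0, 2.0], [2.0, 2.0]]}]}
```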
###### Support encrypted MindIR model

```python
# export model
import mindspore as ms

# define the add network and its inputs x and y here
# ...

# export the encrypted model
ms.export(add, ms.Tensor(x), ms.Tensor(y), file_name='tensor_add_enc', file_format='MINDIR',
          enc_key="asdfasdfasdfasgwegw12310".encode(), enc_mode='AES-GCM')
```

```python
# start the Serving server
import os
import sys
from mindspore_serving import server

def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="add",
                                                 device_ids=(0, 1),
                                                 dec_key='asdfasdfasdfasgwegw12310'.encode(),
                                                 dec_mode='AES-GCM')
    server.start_servables(servable_configs=servable_config)
    server.start_grpc_server(address="127.0.0.1:5500")
    server.start_restful_server(address="127.0.0.1:1500")

if __name__ == "__main__":
    start()
```

###### [BETA] Support incremental inference models consisting of multiple static graphs

An incremental inference model can include a full input graph and an incremental input graph, and Serving orchestrates the two static graphs using a user-defined Python script. For more details, please refer to [Serving pangu alpha](https://gitee.com/mindspore/models/tree/master/official/nlp/Pangu_alpha/serving_increment).

#### Deprecations

##### Python API

- `mindspore_serving.master` and `mindspore_serving.worker` are now deprecated in favor of `mindspore_serving.server`, as shown above. Deprecated interfaces will be deleted in the next iteration.
- The following interfaces are deleted directly. That is, the workers of one Serving server can no longer be deployed on other machines, and users are no longer aware of workers at the interface layer.

```python
mindspore_serving.worker.start_servable
mindspore_serving.worker.distributed.start_distributed_servable
mindspore_serving.master.start_master_server
```

### Contributors

Thanks goes to these wonderful people:

chenweifeng, qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.2.0 Release Notes

### Major Features and Improvements

- [STABLE] Support distributed inference; it needs to cooperate with distributed training to export distributed models with super-large-scale neural network parameters (Ascend 910).
- [STABLE] Support the GPU platform. Serving worker nodes can be deployed on Nvidia GPU, Ascend 310 and Ascend 910.
- This release is based on MindSpore version 1.2.0.
- Support Python 3.8 and 3.9.

### API Change

#### API Incompatible Change

##### Python API

Support deployment of distributed models; refer to the [distributed inference tutorial](https://www.mindspore.cn/serving/docs/en/master/serving_distributed_example.html) for the related API.

#### Deprecations

##### Python API

### Bug Fixes

### Contributors

Thanks goes to these wonderful people:

chenweifeng, qinzheng, xujincai, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.1.1 Release Notes

### Major Features and Improvements

- Adapts the new C++ inference interface of MindSpore version 1.1.1.

### Bug fixes

- [BUGFIX] Fix a bug in transforming results of type int16 in the Python client.
- [BUGFIX] Fix bytes type being misidentified as str type after Python preprocessing and postprocessing.
- [BUGFIX] Fix a bug releasing C++ tensor data when it is wrapped as a numpy object.
- [BUGFIX] Update RuntimeError to a warning log when checking the Ascend environment fails.

## MindSpore Serving 1.1.0 Release Notes

### Major Features and Improvements

- [STABLE] Support gRPC and RESTful APIs.
- [STABLE] Support simple Python APIs for the client and server.
- [STABLE] Support model configuration; users can customize preprocessing and postprocessing for a model.
- [STABLE] Support multiple models; multiple models can run simultaneously.
- [STABLE] Support model batching; multiple instances will be split and combined to meet the batch size requirement of the model.
- This release is based on MindSpore version 1.1.0.

### Bug Fixes

### Contributors

================================================
FILE: RELEASE_CN.md
================================================

# MindSpore Serving Release Notes

[View English](./RELEASE.md)

## MindSpore Serving 2.0.2 Release Notes

### Major Features and Improvements

- Adapted to the MindSpore 2.2.0 interfaces.
- Fixed the third-party OpenSSL vulnerabilities CVE-2023-3446 and CVE-2023-4807.

### Contributors

Thanks goes to these wonderful people:

qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!
## MindSpore Serving 2.0.0 Release Notes

### Major Features and Improvements

- Adapted to the MindSpore 2.0.0rc1 interfaces.
- Fixed the third-party OpenSSL vulnerabilities CVE-2022-4304, CVE-2022-4450, CVE-2023-0286, CVE-2023-0464, CVE-2023-0465 and CVE-2023-0466.

### Contributors

Thanks goes to these wonderful people:

qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.8.0 Release Notes

### Major Features and Improvements

- [STABLE] When Serving deploys a large model with pipeline parallelism, multiple inference instances can be processed in a pipeline-parallel manner.

### Contributors

Thanks goes to these wonderful people:

qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

## MindSpore Serving 1.7.0 Release Notes

### Major Features and Improvements

- [DEMO] Ascend 310P can be used as a hardware backend of MindSpore Serving. For details, see [MindSpore Serving backends](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_install.html#installation).
- [DEMO] When MindSpore Lite is used as the inference backend of MindSpore Serving, the MindIR model format is supported. For details, see [MindSpore Serving backends](https://www.mindspore.cn/serving/docs/zh-CN/master/serving_install.html#installation).

#### Deprecations

##### Python API

- `AclOptions` and `GpuOptions` are removed since version 1.7.0; use `AscendDeviceInfo` and `GPUDeviceInfo` instead.
- `register.declare_servable` and `register.call_servable` are removed since version 1.7.0; use `register.declare_model` and `register.add_stage` instead.
- `register.call_preprocess`, `register.call_preprocess_pipeline`, `register.call_postprocess` and `register.call_postprocess_pipeline` are removed since version 1.7.0; use `register.add_stage` instead.

### Contributors

Thanks goes to these wonderful people:

qinzheng, xuyongfei, zhangyinxia, zhoufeng.

Contributions of any kind are welcome!

================================================
FILE: Third_Party_Open_Source_Software_Notice
================================================

OPEN SOURCE SOFTWARE NOTICE

Please note we provide an open source software notice along with this product and/or this product firmware (in the following just “this product”). The open source software licenses are granted by the respective right holders. And the open source licenses prevail all other license information with regard to the respective open source software contained in the product, including but not limited to End User Software Licensing Agreement. This notice is provided on behalf of Huawei Technologies Co. Ltd. and any of its local subsidiaries which may have provided this product to you in your local country.

Warranty Disclaimer

THE OPEN SOURCE SOFTWARE IN THIS PRODUCT IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
Copyright Notice and License Texts Software: Eigen 3.3.7 Copyright notice: Copyright (C) 2014 Benoit Steiner Copyright (C) 2013 Christian Seiler Copyright (C) 2015 Eugene Brevdo Copyright (C) 2014-2015 Benoit Steiner Copyright (C) 2015 Navdeep Jaitly Copyright (C) 2014 Eric Martin Copyright (C) 2015 Benoit Steiner Copyright (C) 2016 Rasmus Munk Larsen Copyright (C) 2016 Benoit Steiner Copyright (C) 2015 Jianwei Cui Copyright (C) 2016 Eugene Brevdo Copyright (C) 2015 Ke Yang Copyright (C) 2016 Mehdi Goli, Codeplay Software Ltd Copyright (C) 2014 Navdeep Jaitly Copyright (C) 2016 Igor Babuschkin Copyright (C) 2016 Dmitry Vyukov Copyright (C) EDF R&D, lun sep 30 14:23:30 CEST 2002 Copyright (C) 2008 Gael Guennebaud Copyright (C) EDF R&D, lun sep 30 14:23:31 CEST 2002 Copyright (C) 2008-2010 Gael Guennebaud Copyright (C) 2008-2016 Gael Guennebaud Copyright (C) 2009 Mark Borgerding mark a borgerding net Copyright (C) 2008-2009 Gael Guennebaud Copyright (C) 2013 Desire Nuentsa Copyright (C) 2013 Gael Guennebaud Copyright (C) 2011 Gael Guennebaud Copyright (C) 2012 Desire NUENTSA WAKAM Copyright (C) 2009 Benoit Jacob Copyright (C) 2009 Gael Guennebaud Copyright (C) 2006-2010 Benoit Jacob Copyright (C) 2006-2008 Benoit Jacob Copyright (C) EDF R&D, lun sep 30 14:23:28 CEST 2002 Copyright (C) 2010 Manuel Yguel Copyright (C) 2009 Claire Maurice Copyright (C) 2010,2012 Jitse Niesen Copyright (c) 2011, Intel Corporation. All rights reserved. Copyright (C) 2012-2016 Gael Guennebaud Copyright (C) 2016 Tobias Wood Copyright (C) 2010 Jitse Niesen Copyright (C) 2012 Alexey Korepanov Copyright (C) 2010 Vincent Lejeune Copyright (C) 2010 Gael Guennebaud Copyright (C) 2010 Benoit Jacob Copyright (C) 2017 Gael Guennebaud Copyright (C) 2009-2010 Gael Guennebaud Copyright (C) 2008 Benoit Jacob Copyright (C) 2009 Mathieu Gautier Copyright (C) 2010 Hauke Heibel Copyright (C) 2009 Hauke Heibel Copyright (C) 2008-2015 Gael Guennebaud Copyright (C) EDF R&D, mar déc 3 18:59:36 CET 2002 Copyright (C) EDF R&D, lun sep 30 14:23:17 CEST 2002 Copyright (C) EDF R&D, mar déc 3 18:59:35 CET 2002 Copyright (C) 2016 Konstantinos Margaritis Copyright (C) 2007 Julien Pommier Copyright (C) 2008-2011 Gael Guennebaud Copyright (C) 2009 Keir Mierle Copyright (C) 2011 Timothy E. Holy Copyright (C) 2009 Hauke Heibel Copyright (C) 2012 Desire Nuentsa Copyright (C) 2014 Gael Guennebaud Copyright (C) 2015 Tal Hadad @copyright (c) 2009-2014 The University of Tennessee and The University of Tennessee Research Foundation. @copyright (c) 2012-2016 Inria. All rights reserved. @copyright (c) 2012-2014 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria, Univ. Bordeaux. All rights reserved. Copyright 2007-2009 Kitware, Inc. Copyright 2012-2013 Inria Copyright 2012-2013 Emmanuel Agullo Copyright 2012-2013 Mathieu Faverge Copyright 2012 Cedric Castagnede Copyright 2013-2016 Florent Pruvost Copyright 2016 Codeplay Software Ltd. Copyright (c) 2006, 2007 Montel Laurent, Copyright (c) 2008, 2009 Gael Guennebaud, Copyright (c) 2009 Boudewijn Rempt @copyright (c) 2012-2014 Inria. All rights reserved. 
Copyright 2013 Florent Pruvost Copyright (c) 2010 Jitse Niesen, Copyright (C) 2009 Benjamin Schindler Copyright (C) 2016 Pedro Gonnet (pedro.gonnet@gmail.com) Copyright (C) 2016 Benoit Steiner (benoit.steiner.goog@gmail.com) Copyright (C) 2009 Thomas Capricelli Copyright (C) 2012-2013 Desire Nuentsa Copyright (C) 2012-2014 Gael Guennebaud Copyright Jorge More - Argonne National Laboratory Copyright Burt Garbow - Argonne National Laboratory Copyright Ken Hillstrom - Argonne National Laboratory Copyright (C) 2009 Ilya Baran Copyright (c) 2010, Intel Corp. Copyright (C) 2009-2010 Benoit Jacob Copyright (C) 2013-2016 Gael Guennebaud Copyright (C) 2013 Gauthier Brun Copyright (C) 2013 Nicolas Carre Copyright (C) 2013 Jean Ceccato Copyright (C) 2013 Pierre Zoppitelli Copyright (C) 2013 Jitse Niesen Copyright (C) 2014-2017 Gael Guennebaud Copyright (C) 2013-2014 Gael Guennebaud Copyright (C) 2011-2014 Gael Guennebaud Copyright (C) 2012 Désiré Nuentsa-Wakam Copyright (C) 2015 Gael Guennebaud Copyright (C) 2012 Gael Guennebaud Copyright (c) 1994 by Xerox Corporation. All rights reserved. Copyright (C) 2001 Intel Corporation Copyright (c) 2001 Intel Corporation. Copyright (C) 2009 Gael Guennebaud Copyright (C) 2013 Christoph Hertzberg Copyright (C) 2015 Eugene Brevdo Copyright (C) 2016 Mehdi Goli Codeplay Software Ltd. Ralph Potter Codeplay Software Ltd. Luke Iwanski Codeplay Software Ltd. Copyright (C) 2014 Jianwei Cui Copyright (C) 2015 Vijay Vasudevan Copyright (C) 2015 Mehdi Goli Codeplay Software Ltd. Ralph Potter Codeplay Software Ltd. Luke Iwanski Codeplay Software Ltd. Copyright (C) 2014 Navdeep Jaitly Copyright (C) 2011 Gael Guennebaud Copyright (C) 2012 desire Nuentsa Copyright (C) 2012 Kolja Brix Copyright (C) 2011 Kolja Brix Copyright (C) 2011 Andreas Platen Copyright (C) 2012 Chen-Pang He Copyright (C) 2009 Jitse Niesen Copyright (C) 2009-2011 Jitse Niesen Copyright (C) 2012, 2013 Chen-Pang He Copyright (C) 2011 Jitse Niesen Copyright (C) 2012 Giacomo Po Copyright (C) 2008-2010 Gael Guennebaud Copyright (C) 2016 Gael Guennebaud Copyright (C) 2010-2011 Hauke Heibel Copyright (C) 2012 David Harmon Copyright (C) 2007-2009 Benoit Jacob Copyright (C) 2007-2010 Benoit Jacob Copyright (C) 2008-2009 Benoit Jacob Copyright (C) 2009 Kenneth Riddile Copyright (C) 2010 Thomas Capricelli Copyright (C) 2013 Pavel Holoborodko Copyright (C) EDF R&D, lun sep 30 14:23:16 CEST 2002 Copyright (C) EDF R&D, mar déc 3 18:59:37 CET 2002 Copyright (C) 2006-2009 Benoit Jacob Copyright (C) 2008-2010 Benoit Jacob Copyright (c) 2008-2015 Pavel Holoborodko Copyright (C) 20010-2011 Hauke Heibel Copyright (c) 2006, Montel Laurent, Copyright (c) 2007, Allen Winter, Copyright (c) 2007, Alexander Neundorf, Copyright (C) 2008 Guillaume Saupin Copyright (C) 2008-2009 Guillaume Saupin Copyright (C) 2009 Guillaume Saupin Copyright (C) 2010-2016 Konstantinos Margaritis Copyright (C) 2008-2016 Konstantinos Margaritis Copyright (C) 2014 Benoit Steiner (benoit.steiner.goog@gmail.com) Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com) Copyright (c) Fabian Giesen, 2016 Copyright (C) 2010 Konstantinos Margaritis Copyright (C) 2007 Michael Olbrich Copyright (C) 2011 Benoit Jacob Copyright (C) 2011-2012 Jitse Niesen Copyright (C) 2016 Rasmus Munk Larsen (rmlarsen@google.com) Copyright (C) 2008-2014 Gael Guennebaud Copyright (C) 2010-2013 Hauke Heibel Copyright (C) 2006-2008, 2010 Benoit Jacob Copyright (C) 2010-2016 Gael Guennebaud Copyright (C) 2009-2015 Gael Guennebaud Copyright (C) 2009 Ricard Marxer Copyright (C) 2009-2014 
Gael Guennebaud Copyright (C) 2010-2011 Gael Guennebaud Copyright (C) 2009 Rohit Garg Copyright (c) 2006, Timothy A. Davis. Copyright (c) 1998-2003 by the University of Florida. Copyright (C) 2012 Désiré Nuentsa-Wakam Copyright (C) 2008-2012 Gael Guennebaud LDL Copyright (c) 2005 by Timothy A. Davis. All Rights Reserved. Copyright (C) 2010 Daniel Lowengrub Copyright (C) EDF R&D, lun sep 30 14:23:20 CEST 2002 Copyright (C) EDF R&D, lun sep 30 14:23:19 CEST 2002 Copyright (C) 2009, 2010, 2013 Jitse Niesen Copyright (C) 2011, 2013 Chen-Pang He Copyright (C) 2009-2011, 2013 Jitse Niesen Copyright (C) 2011, 2013 Jitse Niesen Copyright (C) 2011 Chen-Pang He Copyright (C) 2010, 2013 Jitse Niesen Copyright (C) 2010-2014 Gael Guennebaud Copyright (C) 2012 The Android Open Source Project (C) Desire NUENTSA WAKAM, INRIA Copyright (C) EDF R&D, lun sep 30 14:23:18 CEST 2002 Copyright (C) 2012 Keir Mierle Copyright (C) 1989, 1991 Free Software Foundation, Inc. Copyright (C) EDF R&D, lun sep 30 14:23:23 CEST 2002 Copyright (C) EDF R&D, lun sep 30 14:23:24 CEST 2002 Copyright (C) EDF R&D, lun sep 30 14:23:27 CEST 2002 Copyright (C) 2007 Free Software Foundation, Inc. Copyright (C) 1991, 1999 Free Software Foundation, Inc. Copyright (C) 2015 Benoit Jacob Geometric Tools, LLC Copyright (c) 1998-2010 Copyright (C) EDF R&D, lun sep 30 14:23:15 CEST 2002 Copyright (C) 2002-2007 Yves Renard Copyright (C) 2012, 2014 Kolja Brix Copyright (C) 1997-2001 Andrew Lumsdaine Lie-Quan Lee Copyright (C) 2012 Desire NUENTSA WAKAM Copyright (C) 2013 Hauke Heibel Copyright (C) 2010-2011 Jitse Niesen Intel Copyright (C) .... Copyright (C) 2010-2017 Gael Guennebaud Copyright (C) 20013 Gael Guennebaud Copyright (C) 2008 Daniel Gomez Ferro Copyright (C) 2013 Désiré Nuentsa-Wakam Copyright (C) 2011-2015 Gael Guennebaud Copyright (C) 20015 Gael Guennebaud Copyright (C) 2014-2015 Gael Guennebaud License: Mozilla Public License (MPL) V2.0 Mozilla Public License Version 2.0 1. Definitions 1.1. “Contributor” means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. 1.2. “Contributor Version” means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor’s Contribution. 1.3. “Contribution” means Covered Software of a particular Contributor. 1.4. “Covered Software” means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. 1.5. “Incompatible With Secondary Licenses” means that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License. 1.6. “Executable Form” means any form of the work other than Source Code Form. 1.7. “Larger Work” means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. “License” means this document. 1.9. “Licensable” means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. 1.10. 
“Modifications” means any of the following: any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or any new file in Source Code Form that contains any Covered Software. 1.11. “Patent Claims” of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. 1.12. “Secondary License” means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses. 1.13. “Source Code Form” means the form of the work preferred for making modifications. 1.14. “You” (or “Your”) means an individual or a legal entity exercising rights under this License. For legal entities, “You” includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, “control” means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. 2. License Grants and Conditions 2.1. Grants Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version. 2.2. Effective Date The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. 2.3. Limitations on Grant Scope The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: for any code that a Contributor has removed from Covered Software; or for infringements caused by: (i) Your and any other third party’s modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or under Patent Claims infringed by Covered Software in the absence of its Contributions. This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4). 2.4. Subsequent Licenses No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3). 2.5. 
Representation Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. 2.6. Fair Use This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. 2.7. Conditions Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1. 3. Responsibilities 3.1. Distribution of Source Form All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients’ rights in the Source Code Form. 3.2. Distribution of Executable Form If You distribute Covered Software in Executable Form then: such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients’ rights in the Source Code Form under this License. 3.3. Distribution of a Larger Work You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s). 3.4. Notices You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. 3.5. Application of Additional Terms You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction. 4. 
Inability to Comply Due to Statute or Regulation If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. 5. Termination 5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice. 5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination. 6. Disclaimer of Warranty Covered Software is provided under this License on an “as is” basis, without warranty of any kind, either expressed, implied, or statutory, including, without limitation, warranties that the Covered Software is free of defects, merchantable, fit for a particular purpose or non-infringing. The entire risk as to the quality and performance of the Covered Software is with You. Should any Covered Software prove defective in any respect, You (not any Contributor) assume the cost of any necessary servicing, repair, or correction. This disclaimer of warranty constitutes an essential part of this License. No use of any Covered Software is authorized under this License except under this disclaimer. 7. Limitation of Liability Under no circumstances and under no legal theory, whether tort (including negligence), contract, or otherwise, shall any Contributor, or anyone who distributes Covered Software as permitted above, be liable to You for any direct, indirect, special, incidental, or consequential damages of any character including, without limitation, damages for lost profits, loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses, even if such party shall have been informed of the possibility of such damages. 
This limitation of liability shall not apply to liability for death or personal injury resulting from such party’s negligence to the extent applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion or limitation of incidental or consequential damages, so this exclusion and limitation may not apply to You. 8. Litigation Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party’s ability to bring cross-claims or counter-claims. 9. Miscellaneous This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 10. Versions of the License 10.1. New Versions Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 10.2. Effect of New Versions You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 10.3. Modified Versions If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. Exhibit A - Source Code Form License Notice This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership. Exhibit B - “Incompatible With Secondary Licenses” Notice This Source Code Form is “Incompatible With Secondary Licenses”, as defined by the Mozilla Public License, v. 2.0. Software: JSON for Modern C++ 3.6.1 Copyright notice: Copyright 2015 Google Inc. All rights reserved. Copyright 2018 Google Inc. All rights reserved. Copyright 2016 Ismael Jimenez Martinez. All rights reserved. Copyright 2017 Roman Lebedev. All rights reserved. Copyright (c) 2012 Two Blue Cubes Ltd. All rights reserved. Copyright (c) 2015 Max Woolf Copyright 2014 The Authors Copyright (c) 2016 Nicolas Seriot Copyright (c) 2015-2017 Niels Lohmann. 
Copyright (c) 2015-2017 Niels Lohmann Copyright (c) 2013-2019 Niels Lohmann . Copyright (c) 2018 Vitaliy Manushkin . Copyright (c) 2012, Erik Edlund Copyright (c) 2013-2019 Niels Lohmann Copyright 2013-2019 [Niels Lohmann](http:nlohmann.me) Copyright (c) 2009 Google Inc. All rights reserved. Copyright (C) 2009 Google Inc. License: MIT License The MIT License Copyright (c) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. Software: re2 20191201 Copyright notice: Copyright (C) 2005 Free Software Foundation, Inc. Copyright (C) 2007 Free Software Foundation, Inc. Copyright (C) 2009 Free Software Foundation, Inc. Copyright (C) 2009 The Android Open Source Project Copyright (c) 2002 by Lucent Technologies. Copyright (c) 2009 The RE2 Authors. Copyright 1999-2005 The RE2 Authors. All Rights Reserved. Copyright 2001-2010 The RE2 Authors. All Rights Reserved. Copyright 2002-2009 The RE2 Authors. All Rights Reserved. Copyright 2003-2009 Google Inc. Copyright 2003-2009 The RE2 Authors. All Rights Reserved. Copyright 2003-2010 Google Inc. All Rights Reserved. Copyright 2004 The RE2 Authors. All Rights Reserved. Copyright 2005 The RE2 Authors. All Rights Reserved. Copyright 2006 The RE2 Authors. All Rights Reserved. Copyright 2006-2007 The RE2 Authors. All Rights Reserved. Copyright 2006-2008 The RE2 Authors. All Rights Reserved. Copyright 2007 The RE2 Authors. All Rights Reserved. Copyright 2008 The RE2 Authors. All Rights Reserved. Copyright 2009 The RE2 Authors. All Rights Reserved. Copyright 2010 The RE2 Authors. All Rights Reserved. Copyright 2012 The Go Authors. Copyright 2015 The RE2 Authors. All Rights Reserved. Copyright 2016 The RE2 Authors. All Rights Reserved. Copyright 2018 The RE2 Authors. All Rights Reserved. License: BSD-3 with additional clause Most files in this release are marked with the copyrights of the organizations who have edited them. The copyrights below are in no particular order and generally reflect members of the Open MPI core team who have contributed code to this release. The copyrights for code used under license from other parties are included in the corresponding files. Copyright (c) 2004-2010 The Trustees of Indiana University and Indiana University Research and Technology Corporation. All rights reserved. Copyright (c) 2004-2017 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. Copyright (c) 2004-2010 High Performance Computing Center Stuttgart, University of Stuttgart. All rights reserved. Copyright (c) 2004-2008 The Regents of the University of California. 
All rights reserved. Copyright (c) 2006-2017 Los Alamos National Security, LLC. All rights reserved. Copyright (c) 2006-2017 Cisco Systems, Inc. All rights reserved. Copyright (c) 2006-2010 Voltaire, Inc. All rights reserved. Copyright (c) 2006-2017 Sandia National Laboratories. All rights reserved. Copyright (c) 2006-2010 Sun Microsystems, Inc. All rights reserved. Use is subject to license terms. Copyright (c) 2006-2017 The University of Houston. All rights reserved. Copyright (c) 2006-2009 Myricom, Inc. All rights reserved. Copyright (c) 2007-2017 UT-Battelle, LLC. All rights reserved. Copyright (c) 2007-2017 IBM Corporation. All rights reserved. Copyright (c) 1998-2005 Forschungszentrum Juelich, Juelich Supercomputing Centre, Federal Republic of Germany Copyright (c) 2005-2008 ZIH, TU Dresden, Federal Republic of Germany Copyright (c) 2007 Evergrid, Inc. All rights reserved. Copyright (c) 2008 Chelsio, Inc. All rights reserved. Copyright (c) 2008-2009 Institut National de Recherche en Informatique. All rights reserved. Copyright (c) 2007 Lawrence Livermore National Security, LLC. All rights reserved. Copyright (c) 2007-2017 Mellanox Technologies. All rights reserved. Copyright (c) 2006-2010 QLogic Corporation. All rights reserved. Copyright (c) 2008-2017 Oak Ridge National Labs. All rights reserved. Copyright (c) 2006-2012 Oracle and/or its affiliates. All rights reserved. Copyright (c) 2009-2015 Bull SAS. All rights reserved. Copyright (c) 2010 ARM ltd. All rights reserved. Copyright (c) 2016 ARM, Inc. All rights reserved. Copyright (c) 2010-2011 Alex Brick . All rights reserved. Copyright (c) 2012 The University of Wisconsin-La Crosse. All rights reserved. Copyright (c) 2013-2016 Intel, Inc. All rights reserved. Copyright (c) 2011-2017 NVIDIA Corporation. All rights reserved. Copyright (c) 2016 Broadcom Limited. All rights reserved. Copyright (c) 2011-2017 Fujitsu Limited. All rights reserved. Copyright (c) 2014-2015 Hewlett-Packard Development Company, LP. All rights reserved. Copyright (c) 2013-2017 Research Organization for Information Science (RIST). All rights reserved. Copyright (c) 2017 Amazon.com, Inc. or its affiliates. All Rights reserved. $COPYRIGHT$ Additional copyrights may follow $HEADER$ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. - Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. The copyright holders provide no reassurances that the source code provided does not infringe any patent, copyright, or any other intellectual property rights of third parties. The copyright holders disclaim any liability to any recipient for claims brought against recipient by any third party for infringement of that parties intellectual property rights. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Software: pip 20.0.2 Copyright notice: Copyright (c) 2010 ActiveState Software Inc. Copyright (c) 1991-2014 Unicode, Inc. All rights reserved. Copyright (C) 2013 Vinay Sajip. Copyright (C) 2013-2015 Vinay Sajip. Copyright 2012 Facebook Copyright (c) 2017 Thomas Kluyver Copyright (c) 2008-2019 Andrey Petrov and contributors (see CONTRIBUTORS.txt) Copyright (c) 2001-2014 Python Software Foundation; All Rights Reserved Copyright (C) 2012-2015 Vinay Sajip. Copyright (C) 1991, 1999 Free Software Foundation, Inc. Copyright (c) 2010-2020 Benjamin Peterson Copyright (c) 2010-2019 Benjamin Peterson Copyright (C) 2016 Jason R Coombs Copyright (C) 2012 The Python Software Foundation. i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014 Python Software Foundation; All Rights Reserved" are retained in Python alone or in any derivative version prepared by Licensee. Copyright (c) 2005-2010 ActiveState Software Inc. Copyright (c) 2006-2013 James Graham and other contributors Copyright (c) 2013 Eddy Petrișor Copyright 2015 Eric Larson Copyright (c) 2013-2018, Kim Davies. All rights reserved. Copyright (C) 2002 Lars Gustaebel Copyright (C) 2008-2011 INADA Naoki Copyright (c) 2008-2016 The pip developers (see AUTHORS.txt file) Copyright 2018 Kenneth Reitz copyright = '2008-2017, PyPA' Copyright (c) Donald Stufft and individual contributors. Copyright (C) 2012-2019 Vinay Sajip. Copyright 2015,2016,2017 Nir Cohen Copyright (c) 2008-2019 The pip developers (see AUTHORS.txt file) Copyright (C) 2013-2017 Vinay Sajip. Copyright 2007 Google Inc. License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 is made available subject to the terms and conditions in CNRI's License Agreement. This Agreement together with Python 1.6.1 may be located on the Internet using the following unique, persistent identifier (known as a handle): 1895.22/1013. This Agreement may also be obtained from a proxy server on the Internet copyright = "Copyright 2014-2019 %s" % author Copyright (c) 2015-2016 Will Bond Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved. Copyright (c) 2010 Jonathan Hartley All rights reserved. i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011 Python Software Foundation; All Rights Reserved" are retained in Python alone or in any derivative version prepared by Licensee. Copyright (C) 2012-2013 Python Software Foundation. Copyright 2013-2014 Ray Holder Copyright (C) 2012-2017 Vinay Sajip. copyright = 'Copyright 2019 Kenneth Reitz' Copyright (c) 2012 by Simon Sapin. 
i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010 Python Software Foundation; All Rights Reserved" are retained in Python alone or in any derivative version prepared by Licensee. Copyright (c) 2003-2019 Paul T. McGuire Copyright (c) 2012 Giorgos Verigakis Copyright (C) 2012-2017 The Python Software Foundation. License: Copyright (c) 2008-2019 The pip developers (see AUTHORS.txt file) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. Software: pytest 1.6.0 Copyright notice: copyright = "2015–2020, holger krekel and pytest-dev team" If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. Copyright (c) 2010 by Armin Ronacher. Copyright (c) 2004-2020 Holger Krekel and others Copyright Holger Krekel and others, 2004-2020. epubcopyright = "2013-2020, holger krekel et alii" License: The MIT License (MIT) Copyright (c) 2004-2020 Holger Krekel and others Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. Software: googletest 1.8.1 Copyright notice: Copyright 2009, Google Inc. Copyright 2008, Google Inc. Copyright 2007 Google Inc. Copyright 2007, Google Inc. Copyright 2013, Google Inc. Copyright 2015, Google Inc. Copyright 2005, Google Inc. Copyright 2008 Google Inc. Copyright 2006, Google Inc. Copyright 2009 Google Inc. All Rights Reserved. Copyright 2013 Google Inc. All Rights Reserved. Copyright 2017 Google Inc. Copyright 2007 Neal Norwitz Copyright 2008 Google Inc. All Rights Reserved. Copyright 2009 Neal Norwitz All Rights Reserved. Copyright 2003 Google Inc. Copyright 2009 Google Inc. Copyright 2008 Google Inc. All Rights Reserved. 
Copyright [2007] Neal Norwitz Portions Copyright [2007] Google Inc. Copyright 2010 Google Inc. All Rights Reserved. Copyright 2010, Google Inc. Copyright 2005 Google Inc. All Rights Reserved. Copyright 2018, Google Inc. Copyright 2003, Google Inc. Copyright 2009 Google Inc. All rights reserved. Copyright 2015 Google Inc. All rights reserved. Copyright 2009 Google Inc. All rights reserved. Copyright 2018 Google LLC. All rights reserved. Copyright 2018, Google LLC. License: BSD 3-Clause License Copyright 2008, Google Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Software: glog 0.4.0 Copyright notice: Copyright (c) 1999, Google Inc. Copyright (c) 2007, Google Inc. Copyright (c) 2006, Google Inc. Copyright (c) 2003, Google Inc. Copyright (c) 1999, 2007, Google Inc. Copyright (c) 2008, Google Inc. Copyright (c) 2009, Google Inc. Copyright (c) 2002, Google Inc. Copyright (c) 2000 - 2007, Google Inc. Copyright (c) 2005 - 2007, Google Inc. Copyright (c) 2004, Google Inc. Copyright (c) 2003-2008, Jouni Malinen and contributors License: BSD 3-Clause License Copyright (c) 2008, Google Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. A function gettimeofday in utilities.cc is based on http://www.google.com/codesearch/p?hl=en#dR3YEbitojA/COPYING&q=GetSystemTimeAsFileTime%20license:bsd The license of this code is: Copyright (c) 2003-2008, Jouni Malinen and contributors All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name(s) of the above-listed copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Software: pybind11 2.4.3 Copyright notice: Copyright (c) 2015 Wenzel Jakob Copyright (c) 2016 Wenzel Jakob Copyright (c) 2016 Trent Houliston and Wenzel Jakob Copyright (c) 2017 Wenzel Jakob Copyright (c) 2017 Jason Rhinelander Copyright (c) 2016 Klemens Morgenstern and Copyright (c) 2017 Henry F. Schreiner Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob Copyright (c) 2016 Wenzel Jakob , All rights reserved. Copyright (c) 2016 Jason Rhinelander Copyright (c) 2019 Google LLC Copyright (c) 2019 Google Inc. Copyright (c) 2016 Ben North Copyright (c) 2016 Klemens D. Morgenstern Copyright (c) 2016 Pim Schellart Copyright (c) 2017 Borja Zarco (Google LLC) Copyright (c) 2016 Ivan Smirnov Copyright (c) 2016 Ivan Smirnov Copyright (c) 2016 Sergey Lyskov Copyright (c) 2018 Hudson River Trading LLC Copyright (c) 2019 Roland Dreier Copyright (c) 2006, 2007 Montel Laurent, Copyright (c) 2008, 2009 Gael Guennebaud, Copyright (c) 2009 Benoit Jacob Copyright 2001-2009 Kitware, Inc. Copyright 2012 Continuum Analytics, Inc. Copyright (c) 2007-2012 University of Illinois at Urbana-Champaign. License:BSD 3-Clause License Copyright (c) 2016 Wenzel Jakob , All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Please also refer to the file CONTRIBUTING.md, which clarifies licensing of external contributions to this project including patches, pull requests, etc. Software: pybind11 2.6.1 Copyright notice: Copyright (c) 2016 Wenzel Jakob , All rights reserved. Copyright (c) 2016 Ben North Copyright (c) 2017 Wenzel Jakob Copyright 2012 Continuum Analytics, Inc. Copyright 2001-2009 Kitware, Inc. Copyright (c) 2016 Ivan Smirnov Copyright (c) 2017 Borja Zarco (Google LLC) copyright = "2017, Wenzel Jakob" Copyright (c) 2016 Jason Rhinelander Copyright (c) 2016 Trent Houliston and Wenzel Jakob Copyright (c) 2016 Wenzel Jakob Copyright (c) 2017 Jason Rhinelander Copyright (c) 2006, 2007 Montel Laurent, Copyright (c) 2008, 2009 Gael Guennebaud, Copyright (c) 2016 Klemens Morgenstern and Wenzel Jakob Copyright (c) 2020 Wenzel Jakob Copyright (c) 2019 Google Inc. Copyright (c) 2019 Roland Dreier Copyright (c) 2018 Hudson River Trading LLC Copyright (c) 2019 Google LLC Copyright (c) 2015 Wenzel Jakob Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob Copyright (c) 2016 Ivan Smirnov Copyright (c) 2016 Klemens D. Morgenstern Copyright (c) 2009 Benoit Jacob Copyright (c) 2016 Pim Schellart Copyright (c) 2020 Wenzel Jakob and Henry Schreiner Copyright (c) 2016 Sergey Lyskov Copyright (c) 2017 Henry F. Schreiner Copyright (c) 2016 Wenzel Jakob , All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

Please also refer to the file .github/CONTRIBUTING.md, which clarifies licensing of external contributions to this project including patches, pull requests, etc.

Software: google/protobuf 3.13.0
Copyright 2008 Google Inc. All rights reserved. Copyright 2008 Google Inc. All rights reserved. Copyright 2007-2010 Baptiste Lepilleur Distributed under MIT license, or public domain if desired and recognized in your jurisdiction. Copyright 2007 Google Inc. All Rights Reserved. Copyright 2012 Google Inc. All rights reserved. Copyright 2014 Google Inc. All rights reserved. Copyright 2019 Google Inc. All rights reserved. Copyright 2008 Google Inc. All Rights Reserved. copyright = u"2008, Google LLC" Copyright 2017 Google Inc. All rights reserved. Copyright 2008 Google Inc. Copyright 2015 Google Inc. All rights reserved. Copyright 2019 Google Inc. All rights reserved. Copyright (c) 2006, Google Inc. Copyright (c) 2007-2010 Baptiste Lepilleur Copyright 2017 Google Inc. All rights reserved. Copyright 2015 Google Inc. All rights reserved. Copyright 2018 Google Inc. All rights reserved. Copyright 2009 Google Inc. All rights reserved. Copyright 2007-2011 Baptiste Lepilleur Distributed under MIT license, or public domain if desired and recognized in your jurisdiction. Copyright 2011 Baptiste Lepilleur Distributed under MIT license, or public domain if desired and recognized in your jurisdiction. Copyright 2015, Google Inc. Copyright 2019 Google LLC. All rights reserved. Copyright 2016 Google Inc. All rights reserved. Copyright 2005 Google Inc. Copyright 2016 Google Inc. All rights reserved.

License: BSD 3-Clause License

Copyright 2008 Google Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

Code generated by the Protocol Buffer compiler is owned by the owner of the input file used when generating it. This code is not standalone and requires a support library to be linked with it. This support library is itself covered by the above license.

Software: libevent 2.1.12
Copyright notice:
Copyright (C) 1998 - 2012, Daniel Stenberg, , et al. COPYRIGHT AND PERMISSION NOTICE Copyright (c) 1996 - 2013, Daniel Stenberg, . Copyright (C) 2012, iSEC Partners. Copyright (c) 1987, 1993, 1994, 1995 Copyright (c) 1987, 1993, 1994, 1996 Copyright 2002 Niels Provos Copyright (c) 2007-2012 Niels Provos and Nick Mathewson Copyright (c) 2000-2007 Niels Provos Copyright (c) 2007-2012 Niels Provos, Nick Mathewson Copyright (c) 2009-2012 Niels Provos and Nick Mathewson Copyright (c) 2006-2007 Niels Provos Copyright (c) 2008-2012 Niels Provos and Nick Mathewson Copyright (c) 1991, 1993 Copyright (c) 2009, Michihiro NAKAJIMA Copyright 2000-2013 Kitware, Inc. Copyright 2000-2011 Insight Software Consortium notices of original copyright by their contributors; see each source Copyright (C) 1996-2018 Free Software Foundation, Inc. Copyright (c) 2010 Chris Davis, Niels Provos, and Nick Mathewson Copyright (c) 2010-2012 Niels Provos and Nick Mathewson Copyright (c) 1996, David Mazieres Copyright (c) 2008, Damien Miller Copyright (c) 2002-2007 Niels Provos Copyright (c) 2002-2006 Niels Provos Copyright (c) 2009-2012 Niels Provos, Nick Mathewson Copyright 2000-2009 Niels Provos Copyright 2009-2012 Niels Provos and Nick Mathewson Copyright 2000-2007 Niels Provos Copyright 2007-2012 Niels Provos, Nick Mathewson Copyright 2003-2009 Niels Provos Copyright 2006-2007 Niels Provos Copyright 2007-2012 Nick Mathewson and Niels Provos Copyright (c) 2005-2007 Niels Provos Copyright (c) 2003-2009 Niels Provos Copyright 2007-2012 Niels Provos and Nick Mathewson Copyright (c) 2007 Sun Microsystems. All rights reserved. Copyright (c) 2008-2012 Niels Provos, Nick Mathewson Copyright 2002 Christopher Clark Copyright 2005-2012 Nick Mathewson Copyright 2001-2007 Niels Provos Copyright (c) 2012 Niels Provos and Nick Mathewson Copyright (c) 2000 Dug Song Copyright (c) 1993 The Regents of the University of California. Copyright (c) 1998 Todd C. Miller Copyright (c) 2003 Michael A. Davis Copyright (c) 2007 Sun Microsystems Copyright (c) 2002 Christopher Clark Copyright (c) 2006 Maxim Yegorushkin Copyright (c) 2010 BitTorrent, Inc. Copyright (c) 2005-2012 Niels Provos and Nick Mathewson Copyright (c) 1993 Copyright 2003 Michael A. Davis
Copyright 2003-2007 Niels Provos Copyright 2008-2012 Niels Provos and Nick Mathewson Copyright (c) 2003-2007 Niels Provos Copyright (c) 2013 Niels Provos and Nick Mathewson Copyright (c) 2009-2012 Nick Mathewson and Niels Provos Copyright (c) 2007-2013 Niels Provos and Nick Mathewson Copyright (c) 2012 Ross Lagerwall
tinytest.c -- Copyright 2009-2012 Nick Mathewson
tinytest.h -- Copyright 2009-2012 Nick Mathewson
tinytest_macros.h -- Copyright 2009-2012 Nick Mathewson

Libevent is available for use under the following license, commonly known as the 3-clause (or "modified") BSD license:

==============================

Copyright (c) 2000-2007 Niels Provos
Copyright (c) 2007-2012 Niels Provos and Nick Mathewson

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. The name of the author may not be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

==============================

Portions of Libevent are based on works by others, also made available by them under the three-clause BSD license above. The copyright notices are available in the corresponding source files; the license is as above. Here's a list:

log.c: Copyright (c) 2000 Dug Song Copyright (c) 1993 The Regents of the University of California.
strlcpy.c: Copyright (c) 1998 Todd C. Miller
win32select.c: Copyright (c) 2003 Michael A. Davis
evport.c: Copyright (c) 2007 Sun Microsystems
ht-internal.h: Copyright (c) 2002 Christopher Clark
minheap-internal.h: Copyright (c) 2006 Maxim Yegorushkin

==============================

The arc4module is available under the following, sometimes called the "OpenBSD" license:

Copyright (c) 1996, David Mazieres
Copyright (c) 2008, Damien Miller

Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.

THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
==============================

The Windows timer code is based on code from libutp, which is distributed under this license, sometimes called the "MIT" license.

Copyright (c) 2010 BitTorrent, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

==============================

The wepoll module is available under the following, sometimes called the "FreeBSD" license:

Copyright 2012-2020, Bert Belder
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

==============================

The ssl-client-mbedtls.c is available under the following license:

Copyright (C) 2006-2015, ARM Limited, All Rights Reserved
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
This file is part of mbed TLS (https://tls.mbed.org)

Software: grpc 1.36.1
Copyright notice:
Copyright 2015 The gRPC Authors Copyright 2016 The gRPC Authors Copyright 2018 The gRPC Authors Copyright 2019 The gRPC Authors Copyright 2018 The gRPC Authors Copyright © 2018 gRPC. Copyright 2016 gRPC authors. Copyright 2017 gRPC authors. Copyright 2019 gRPC authors. Copyright (C) 1995, 1996, 1997, and 1998 WIDE Project. Copyright (C) 2009 - 2013 by Daniel Stenberg et al Copyright (c) 2004, 2006-2010 Michael Roth Copyright (c) 2004-2009 Michael Roth Copyright (c) 2004-2010 Michael Roth Copyright (c) 2006-2008 Michael Roth Copyright (c) 2009-2011, Google Inc. Copyright (c) 2018, Google Inc. Copyright 2007 Google Inc. All Rights Reserved. Copyright 2008 Google Inc. Copyright 2013 Google Inc. Copyright 2014 Google Inc. Copyright 2014 gRPC authors. Copyright 2014, Google Inc. Copyright 2015 The gRPC Authors Copyright 2015 gRPC authors. Copyright 2015, Google Inc. Copyright 2015-2016 gRPC authors. Copyright 2015-2017 gRPC authors. Copyright 2016 Google Inc. Copyright 2016 The Chromium Authors. Copyright 2016 gRPC authors. Copyright 2016, Google Inc. Copyright 2017 The gRPC Authors Copyright 2017 gRPC authors. Copyright 2018 The Bazel Authors. Copyright 2018 The gRPC Authors Copyright 2018 The gRPC Authors. Copyright 2018 gRPC Authors. Copyright 2018 gRPC authors. Copyright 2018, gRPC Authors Copyright 2019 Istio Authors. All Rights Reserved. Copyright 2019 The Bazel Authors. Copyright 2019 The gRPC Authors Copyright 2019 The gRPC Authors. Copyright 2019 The gRPC authors. Copyright 2019 gRPC authors. Copyright 2019 the gRPC authors. Copyright 2019, Google Inc. Copyright 2020 The gRPC Authors Copyright 2020 The gRPC Authors. Copyright 2020 The gRPC authors. Copyright 2020 gRPC authors. Copyright 2020 the gRPC authors. Copyright 2020 王一 Wang Yi Copyright 2021 The gRPC Authors Copyright 2021 The gRPC authors. Copyright 2021 gRPC authors. Copyright 2021 the gRPC authors. Copyright 2015 The gRPC Authors Copyright 2017 The gRPC Authors Copyright 2015 gRPC authors. Copyright 2016 gRPC authors. Copyright 2020 The gRPC Authors

Software: cmake-modules cf2e087039f81d13e687cf6c2b1b382b9c1e756f
Copyright notice:
Copyright 2009 Kitware, Inc. Copyright 2009 Will Dicharry Copyright 2005-2009 Kitware, Inc. Copyright Iowa State University 2009-2010. Copyright 2006-2009 Kitware, Inc. Copyright 2006-2008 Andreas Schneider Copyright 2007 Wengo Copyright 2007 Mike Jackson Copyright 2008 Andreas Pakulat Copyright 2008-2010 Philip Lowman Copyright 2009 Alexander Neundorf Copyright (c) 2012 - 2017, Lars Bilke Copyright (c) 2012-2016 Sascha Kratky Copyright 2012-2018 Sascha Kratky Copyright (c) 2012-2018, OpenGeoSys Community (http://www.opengeosys.org) Copyright (c) 2012 - 2015, Lars Bilke Copyright 2008-2009 Philip Lowman Copyright 2010 Iowa State University (Ryan Pavlik ) Copyright 2000-2009 Kitware, Inc., Insight Software Consortium Copyright 2010-2011 Kitware, Inc.
Copyright Iowa State University 2009-2011

Boost Software License - Version 1.0 - August 17th, 2003

Permission is hereby granted, free of charge, to any person or organization obtaining a copy of the software and accompanying documentation covered by this license (the "Software") to use, reproduce, display, distribute, execute, and transmit the Software, and to prepare derivative works of the Software, and to permit third-parties to whom the Software is furnished to do so, all subject to the following:

The copyright notices in the Software and this entire statement, including the above license grant, this restriction and the following disclaimer, must be included in all copies of the Software, in whole or in part, and all derivative works of the Software, unless such copies or derivative works are solely in the form of machine-executable object code generated by a source language processor.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

Software: abseil-cpp 20200923.3
Copyright notice:
Copyright 2016 Google Inc. All Rights Reserved. Copyright 2017 Google Inc. All Rights Reserved. Copyright 2017 The Abseil Authors. Copyright 2018 The Abseil Authors. Copyright 2019 The Abseil Authors. Copyright 2020 The Abseil Authors.

Apache License
Version 2.0, January 2004
https://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship.
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Software: c-ares 1.15.0
Copyright notice:
Copyright (c) 2012 Dan Winship Copyright (C) 2005 by Dominick Meglio Copyright (C) 2009-2013 by Daniel Stenberg Copyright (C) 2003-2018 Free Software Foundation, Inc. Copyright (c) 2011 Daniel Stenberg Copyright (c) 1996-1999 by Internet Software Consortium. Copyright (C) 2005, 2013 by Dominick Meglio Copyright (C) 2017 by John Schember Copyright (C) 2008-2013 by Daniel Stenberg Copyright 2004 by Daniel Stenberg define ARES_COPYRIGHT "2004 - 2017 Daniel Stenberg, ." Copyright (c) 1996,1999 by Internet Software Consortium. Copyright (c) 2013 Roy Stogner Copyright (C) 2004-2005, 2007-2008, 2011-2015 Free Software Foundation, Inc. Copyright (C) 1996-2018 Free Software Foundation, Inc. Copyright (C) 2005 - 2010, Daniel Stenberg Copyright (c) 2012 Philip Withnall Copyright (C) 2004 - 2012 by Daniel Stenberg et al Copyright (C) 2004-2010 by Daniel Stenberg. Copyright (C) 2004-2009 by Daniel Stenberg. Copyright (C) 2004-2009 by Daniel Stenberg Copyright 2003 Google Inc. Copyright (c) 2008 Steven G. Johnson Copyright 2005 by Dominick Meglio.
Copyright (C) 2017 by John Schember Copyright (C) 2010 Jeremy Lal Copyright (C) 2009 by Jakub Hrozek Copyright (c) 2012 Christian Persch Copyright (C) 2013 by Daniel Stenberg Copyright 2005, Google Inc. Copyright 2013, Google Inc. Copyright (C) 2011 Free Software Foundation, Inc. Copyright (C) 2018 by John Schember Copyright (c) 2008 Benjamin Kosnik Copyright 2005 by Dominick Meglio Copyright (C) 2004-2005, 2007, 2009, 2011-2015 Free Software Foundation, Inc. Copyright (c) 2012 Paolo Borelli Copyright (C) 2009 by Daniel Stenberg et al Copyright (C) 1996-2001, 2003-2015 Free Software Foundation, Inc. Copyright 1998, 2000 by the Massachusetts Institute of Technology. Copyright (C) 2004-2005, 2007-2009, 2011-2015 Free Software Foundation, Inc. Copyright (C) 2009-2018 Free Software Foundation, Inc. Copyright 2005 by Dominick Meglio. Copyright (c) 2013 Daniel Stenberg Copyright (C) 1994 X Consortium Copyright 2008 Google Inc. Copyright (C) 1999-2018 Free Software Foundation, Inc. Copyright 1998, 2011, 2013 by the Massachusetts Institute of Technology. Copyright 1998 by the Massachusetts Institute of Technology. Copyright (C) 2004-2010 by Daniel Stenberg Copyright (c) 2004 by Internet Systems Consortium, Inc. ("ISC") Copyright 1992-2018 Free Software Foundation, Inc. Copyright 2015, Google Inc. Copyright (C) 2004-2018 Free Software Foundation, Inc. Copyright (C) 2012 Free Software Foundation, Inc. Copyright (C) 2010-2012 by Daniel Stenberg Copyright (c) 2012 Xan Lopez - aresversion.h: copyright end range year is now 2013 Copyright (c) 2014, 2015 Google Inc.; contributed by Alexey Sokolov Copyright 2000 by the Massachusetts Institute of Technology. Copyright (C) 2009 - 2013 by Daniel Stenberg et al Copyright (C) 2004 by Daniel Stenberg et al Copyright (C) 2006-2018 Free Software Foundation, Inc. Copyright (C) 2008 - 2013 by Daniel Stenberg et al Copyright 2006, Google Inc. Copyright (C) 2009-2016 by Daniel Stenberg Copyright (C) 2004-2009 by Daniel Stenberg Copyright (C) 2004-2010 by Daniel Stenberg Copyright (c) 2007 - 2018, Daniel Stenberg with many contributors, see AUTHORS file. Copyright 2010 by Ben Greear Copyright 2007, Google Inc. Copyright (C) 1997-2018 Free Software Foundation, Inc. Copyright 1998, 2011 by the Massachusetts Institute of Technology. Copyright 1998, 2000 by the Massachusetts Institute of Technology. Copyright (C) 2004-2011 by Daniel Stenberg Copyright (C) 2008 - 2009 by Daniel Stenberg et al Copyright (C) 2005-2013 by Daniel Stenberg et al Copyright 2008, Google Inc. Copyright (C) 2004-2017 by Daniel Stenberg Copyright (C) 2016 by Daniel Stenberg Copyright (C) 2010-2013 by Daniel Stenberg Copyright (c) 1987-2001 The Regents of the University of California. Copyright (C) 2008 - 2012 by Daniel Stenberg et al Copyright (C) 2009-2013 by Daniel Stenberg et al Copyright (c) 2015 Bastien ROUCARIES Copyright 2000 by the Massachusetts Institute of Technology. Copyright 2005 Dominick Meglio Copyright 1998 by Daniel Stenberg Copyright (c) 2011 Daniel Richard G. Copyright (C) 2001-2018 Free Software Foundation, Inc. Copyright (C) 2004 - 2011 by Daniel Stenberg et al Copyright (C) 2004 - 2013 by Daniel Stenberg et al Copyright (C) 2012 Marko Kreen Copyright (C) 2008-2010 by Daniel Stenberg Copyright (C) 2008 by Daniel Stenberg et al Copyright 1998 by the Massachusetts Institute of Technology. Copyright (C) 2008-2010 by Daniel Stenberg Copyright (C) 2017 by John Schember Copyright (C) 2004, 2011-2015 Free Software Foundation, Inc. Copyright (C) 2002-2018 Free Software Foundation, Inc. 
Copyright (C) 2007-2013 by Daniel Stenberg Copyright 2009 Google Inc. Copyright (C) 1994-2018 Free Software Foundation, Inc. Copyright (C) 1992-1996, 1998-2012 Free Software Foundation, Inc. Copyright (C) 2014 Free Software Foundation, Inc. Copyright (c) 2012 Zack Weinberg

Copyright 1998 by the Massachusetts Institute of Technology.

Permission to use, copy, modify, and distribute this software and its documentation for any purpose and without fee is hereby granted, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation, and that the name of M.I.T. not be used in advertising or publicity pertaining to distribution of the software without specific, written prior permission. M.I.T. makes no representations about the suitability of this software for any purpose. It is provided "as is" without express or implied warranty.

Software: numpy 1.17.0
Copyright notice:
Copyright (c) 1995, 1996, 1997 Jim Hugunin, hugunin@mit.edu Copyright 2014 Melissa O'Neill Copyright (c) 2006, University of Georgia and Pierre G.F. Gerard-Marchant All rights reserved. Copyright 1999-2004 Pearu Peterson all rights reserved, Pearu Peterson Copyright (c) 2009-2019: Jeff Bezanson, Stefan Karpinski, Viral B. Shah, and other contributors: Copyright (c) 2003-2005, Jean-Sebastien Roy (js@jeannot.org) copyright = '2008-2019, The SciPy community' Copyright 2002 Pearu Peterson all rights reserved, Pearu Peterson Copyright (c) 2019 NumPy Developers Copyright (c) 2005-2017, NumPy Developers. Copyright (c) 2018 Melissa E. O'Neill Copyright (c) 2008 Ian Bicking and Contributors Copyright (c) 1992-2013 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. Copyright 1999, 2000, 2001 Regents of the University of California. Copyright (c) 2005, NumPy Developers Copyright (c) 2019 Kevin Sheppard. All rights reserved. Copyright 2010-2012, D. E. Shaw Research. Copyright (c) 2011 by Mark Wiebe (mwwiebe@gmail.com) Copyright 2006, Dean Edwards Copyright (c) 2014 Ryan Juckett Copyright 2001-2005 Pearu Peterson all rights reserved, Pearu Peterson copyright = u'2017-2018, NumPy Developers' Copyright (C) 2004-2018 Max-Planck-Society \author Martin Reinecke Copyright (c) 2010-2011 by Mark Wiebe (mwwiebe@gmail.com) Copyright (c) 2012 Stephen Montgomery-Smith Copyright (c) 2005-2019, NumPy Developers. Copyright (c) 2011 Enthought, Inc Copyright (c) 2007, 2011 David Schultz Copyright (c) 2000-2013 The University of California Berkeley. All rights reserved. Copyright (c) 2015 Pauli Virtanen All rights reserved. Copyright 1999,2000 Pearu Peterson all rights reserved, Pearu Peterson Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, All rights reserved. Copyright (c) 2005-2015, NumPy Developers. Copyright (c) 2011 by Enthought, Inc. Copyright (c) 2010 by Mark Wiebe (mwwiebe@gmail.com) Copyright 2000 Pearu Peterson all rights reserved, Pearu Peterson f90: Copyright Absoft Corporation 1994-1998 mV2; Cray Research, Inc. 1994-1996 CF90 (2.x.x.x f36t87) Version 2.3 Wed Apr 19, 2006 13:05:16 Copyright (c) 2006-2013 The University of Colorado Denver. All rights reserved. f90: Copyright Absoft Corporation 1994-2002; Absoft Pro FORTRAN Version 8.0 Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved. Copyright 1999--2011 Pearu Peterson all rights reserved, Pearu Peterson Copyright 1999 - 2011 Pearu Peterson all rights reserved. Copyright 2015 Robert Kern Copyright (c) 2015 Melissa E. O'Neill
Copyright (c) 2005-2019, NumPy Developers. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the NumPy Developers nor the names of any contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

Software: Python 3.7.5
Copyright notice:
Copyright (c) 1999-2000 by Secret Labs AB Copyright (C) 2005-2007 Gregory P. Smith (greg@krypto.org) Copyright (c) 2003. . Copyright (C) 2005-2010 Gregory P. Smith (greg@krypto.org) Copyright 1996,1997 by Oliver Andrich, Koblenz, Germany. Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006 Python Software Foundation. Copyright (c) 1995-2001 Corporation for National Research Initiatives. All rights reserved. Copyright 1994 by Lance Ellinghouse Cathedral City, California Republic, United States of America. Copyright (C) 2001 Python Software Foundation Barry Warsaw , 2000. libffi 2.00-beta - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (c) 2008-2012 Stefan Krah. All rights reserved. 2001-07-01 fl added BIGCHARSET support (from Martin von Loewis) ``'Copyright 1991-1995 Stichting Mathematisch Centrum, Amsterdam'`` Copyright (C) 2003 Python Software Foundation Copyright (C) 2001-2016 Vinay Sajip. All Rights Reserved. Copyright 1995-1997, Automatrix, Inc., all rights reserved. Copyright (c) 2002 MyCompanyName. All rights reserved. %version%, (c) 2001-2019 Python Software Foundation. Copyright (c) 2004 by Peter Astrand Copyright (c) 1999-2002 by Fredrik Lundh. Copyright (c) 1991-1995 Stichting Mathematisch Centrum, Amsterdam.\n\ AIX ABI support (c) 2002 Free Software Foundation, Inc. ( Copyright (c) 2011 Stefan Krah. All rights reserved. ) 2013-02-04 mrab added fullmatch primitive 2003-10-17 gn implemented non recursive scheme 2003-04-18 mvl fully support 4-byte codes Copyright (C) 1996-2018 Free Software Foundation, Inc. portions copyright 2001, Autonomous Zones Industries, Inc., all rights... Copyright (c) 1999-2002 by Secret Labs AB. Copyright (C) 1986 Gary S. Brown. You may use this program, or code or tables extracted from it, as desired without restriction. -- Copyright (c) IBM Corporation, 2003, 2008. All rights reserved. -- ; Copyright (c) 2008-2016 Stefan Krah. All rights reserved. Copyright (c) 2001-2019 Python Software Foundation.\n\ Copyright 2008 Armin Ronacher.
Copyright © 2000 BeOpen.com. All rights reserved. (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org) Copyright (c) 2005-2006 ActiveState Software Inc. Copyright (C) 1994 Steen Lumholt. Copyright (c) 1999 by Fredrik Lundh. libffi - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (c) 2006-2008 Alexander Chemeris Copyright (C) 2002 Lars Gustaebel Copyright (c) 1999-2003 Steve Purcell Darwin ABI support (c) 2001 John Hornkvist Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com) Copyright (c) 1999-2008 by Fredrik Lundh i.e., "Copyright © 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 -- Copyright (c) IBM Corporation, 2005, 2009. All rights reserved. -- Copyright (c) 2001-2017 Expat maintainers Copyright (c) 2001-2012 Python Software Foundation. All Rights Reserved. Copyright (C) 2002-2006 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org (c) 2002 Gregory P. Ward. All Rights Reserved. Copyright (c) 2000 BeOpen.com.\n\ Copyright (C) 2001-2007 Python Software Foundation Author: Ben Gertzfield, Barry Warsaw Contact: email-sig@python.org Copyright (C) 2003-2004 Federico Di Gregorio 2001-05-14 fl fixes for 1.5.2 compatibility Copyright © 1995-2000 Corporation for National Research Initiatives. All rights reserved. Copyright (C) 1995, 1996, 1997, 1998, and 1999 WIDE Project. copyright, i.e., "Copyright © 2001-2019 Python Software Foundation; All Rights Reserved" are retained in Python |release| alone or in any derivative version prepared by Licensee. Copyright (C) 2001-2007 Python Software Foundation Author: Barry Warsaw, Thomas Wouters, Anthony Baxter Contact: email-sig@python.org Copyright (C) 2005-2010 Gregory P. Smith (greg@krypto.org) Copyright 2001-2017 by Vinay Sajip. All Rights Reserved. Copyright (c) 2000 Doug White, 2006 James Knight, 2007 Christian Heimes All rights reserved. Copyright (C) 1999-2001 Gregory P. Ward. Copyright (c) 1999-2002 by Fredrik Lundh + Copyright 2007 Python Software Foundation. else if (config == (void )2000 && (c) == 0x9B1D) { \ Copyright (c) 1999-2002 by Secret Labs AB 2002-11-09 fl fixed empty sub/subn return type Copyright 2009 Gabriel A. Genellina Copyright (c) 2003-2009 by Fredrik Lundh. All rights reserved. Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved. (c) Copyright Guido van Rossum, 2000. Copyright (C) 1995, 1996, 1997, and 1998 WIDE Project. Copyright (C) 2011-2012 Vinay Sajip. Copyright 2006 Google, Inc. All Rights Reserved. (c) Copyright Marc-Andre Lemburg, 2005. Copyright (C) YEAR ORGANIZATION FIRST AUTHOR , YEAR. Copyright (c) 1999-2000, Marc-Andre Lemburg; mailto:mal@lemburg.com Copyright (c) 1995-2000, Corporation for National Research Initiatives. Copyright (c) 1999 Toby Dickenson Copyright (C) 2001,2002 Python Software Foundation csv package unit tests Copyright (C) 2005, 2006 Martin von Löwis Licensed to PSF under a Contributor Agreement. Copyright (c) 1997 by Fredrik Lundh Copyright (c) 2002-2006 Python Software Foundation. All rights reserved. Copyright (c) 2002 Roger Sayle Copyright 1995-1996 by Fred L. Drake, Jr. and Virginia Polytechnic Institute and State University, Blacksburg, Virginia, USA. types.c - Copyright (c) 1996, 1998 Red Hat, Inc. Copyright 2000, Mojam Media, Inc., all rights reserved. 
Copyright (C) 1994 X Consortium Copyright (C) 2002-2004 Python Software Foundation Copyright (C) 2004-2006 Python Software Foundation Authors: Baxter, Wouters and Warsaw Contact: email-sig@python.org Copyright (c) 2002 Jorge Acereda & darwin.S - Copyright (c) 1996, 1998, 2001, 2002, 2003 Red Hat, Inc. Copyright © 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved. "Copyright 1995-1996 by Virginia Polytechnic Institute & State\n\ Copyright 2012, Samuel Neves . You may use this under the terms of the CC0, the OpenSSL Licence, or the Apache Public License 2.0, at your option. The terms of these licenses can be found at: Copyright (C) 1995-2011 Jean-loup Gailly and Mark Adler x86-ffitarget.h - Copyright (c) 1996-2003 Red Hat, Inc. Copyright 1991-1995, Stichting Mathematisch Centrum, all rights reserved. darwin64.S - Copyright (c) 2006 Free Software Foundation, Inc. Copyright (C) 2001-2006 Python Software Foundation Author: Ben Gertzfield Contact: email-sig@python.org Copyright (c) 2005 Don Owens All rights reserved. win32.S - Copyright (c) 1996, 1998, 2001, 2002 Red Hat, Inc. (c) Copyright 2005, Marc-Andre Lemburg (mal@lemburg.com). library/xml.etree.elementtree,,:include, Copyright (c) . ffi.c - Copyright (c) 1998 Geoffrey Keating Copyright 2006 Georg Brandl. Copyright (C) 2005-2010 Gerhard Häring (c) 2013-2017 Christian Heimes Copyright 1992-2018 Free Software Foundation, Inc. Copyright (c) 1990-1995, Stichting Mathematisch Centrum. Copyright (C) 2001-2017 Vinay Sajip. All Rights Reserved. i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019 Python Software Foundation; 2001-10-24 fl added finditer primitive (for 2.2 only) Copyright (C) 2001-2006 Python Software Foundation Author: Keith Dart Contact: email-sig@python.org Copyright (c) 1999-2009 by Fredrik Lundh. (c) 2000 Peter Bosch. All Rights Reserved. ffitarget.h - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (C) 2012 Free Software Foundation, Inc. Copyright (C) 2002-2007 Python Software Foundation Author: Ben Gertzfield Contact: email-sig@python.org Copyright 1994 by Lance Ellinghouse, Cathedral City, California Republic, United States of America. Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved. Copyright (C) 2006 - 2010 Gregor Lingl email: glingl@aon.at Copyright © 2001-2019 Python Software Foundation. All rights reserved. Copyright (c) 2009,2010 Zmanda Inc. Copyright (c) 1998-2008 The OpenSSL Project. All rights reserved. Copyright 1996 by Sam Rushing Copyright (c) 1998-2000 Thai Open Source Software Center Ltd and Clark Cooper (c) Craig Reese, Joe Campbell and Jeff Poskanzer 1989 / Copyright (c) 1999, 2000, 2001 Steve Purcell This module is free software, and you may redistribute it and/or modify it under the same terms as Python itself, so long as this copyright message and disclaimer are retained in their original form. -- Copyright (c) IBM Corporation, 2005, 2008. All rights reserved. -- Copyright (C) 2001-2012 Python Software Foundation. All Rights Reserved. dnl Copyright © 2004 Scott James Remnant . Copyright (C) 2002-2007 Python Software Foundation Contact: email-sig@python.org Copyright (c) 1998-2001 by Secret Labs AB. All rights reserved. Copyright (c) 2009,2010 Dustin J. 
Mitchell Copyright (c) 2002 Bo Thorsen 2001-10-21 fl added sub/subn primitive copyright, i.e., "Copyright © 2001-2018 Python Software Foundation; All Rights Reserved" are retained in Python 3.7 alone or in any derivative version prepared by Licensee. Copyright (C) 2011-2013 Vinay Sajip. Copyright (c) 1991, 2000, 2001 by Lucent Technologies. Copyright (c) 2010 Python Software Foundation. All Rights Reserved. Copyright 1992-1994, David Gottner " SRE 2.2.2 Copyright (c) 1997-2002 by Secret Labs AB "; ; Copyright (c) 2004, Outercurve Foundation. Copyright (c) 2003-2005 by Peter Astrand Copyright (c) 1999 by Secret Labs AB libffi PyOBJC - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (c) 1999-2009 by Fredrik Lundh Copyright 2007 Google, Inc. All Rights Reserved. -- Copyright (c) IBM Corporation, 2004, 2008. All rights reserved. -- Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019 Python Software Foundation. All rights reserved. Copyright (C) 2001-2007 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (c) 2000 BeOpen.com. All rights reserved. Copyright © 1991-1995 Stichting Mathematisch Centrum. All rights reserved. Copyright (C) 2000 Luke Kenneth Casson Leighton 2001-12-07 fl fixed memory leak in sub/subn (Guido van Rossum) -- Copyright (c) IBM Corporation, 2001, 2008. All rights reserved. -- Virginia, USA. Portions copyright 1991-1995 by Stichting Mathematisch\n\ Copyright (c) 2004 Free Software Foundation, Inc. so portions are Copyright (C) 2001,2002 Python Software Foundation, and were written by Barry Warsaw. Copyright (C) 2004-2010 Gerhard Häring Copyright (c) 2004 Python Software Foundation. (c) Copyright 2000 Guido van Rossum. Copyright 2007 Georg Brandl. Copyright (c) 1999 by Secret Labs AB. Copyright (c) 2002 Unicode, Inc. All Rights reserved. Copyright 2009 Brian Quinlan. All Rights Reserved. Copyright (c) 2008-2009, Google Inc. Copyright (c) 2001-2006 Twisted Matrix Laboratories. (c) Copyright CNRI, All Rights Reserved. NO WARRANTY. License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 is made available subject to the terms and conditions in CNRI's License Agreement. This Agreement together with Python 1.6.1 may be located on the Internet using the following unique, persistent identifier (known as a handle): 1895.22/1013. This Agreement may also be obtained from a proxy server on the Internet Copyright (c) Corporation for National Research Initiatives. Copyright (c) 2008-2016 Stefan Krah. All rights reserved. Copyright 2001-2016 by Vinay Sajip. All Rights Reserved. if (config == (void )2000 && (c) == 0x20B9F) { \ ppc-ffitarget.h - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (c) 2001-2006 Gregory P. Ward. All rights reserved. Copyright (c) 1997-2001 by Secret Labs AB. All rights reserved. Copyright 2000 by Timothy O'Malley Copyright (C) 2007-2012 Michael Foord & the mock team E-mail: fuzzyman AT voidspace DOT org DOT uk Copyright (C) 2011-2014 Vinay Sajip. Copyright (c) 2002 Ranjit Mathew ppc-darwin.h - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. ppc64-darwinclosure.S - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. 
based on ppcclosure.S x86-ffi64.c - Copyright (c) 2002 Bo Thorsen Copyright (C) 2001-2006 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (c) 2004, Outercurve Foundation. Copyright (c) 1995-2001 Corporation for National Research Initiatives.\n\ Copyright 1999, Bioreason, Inc., all rights reserved. 2001-10-20 fl added split primitive; re-enable unicode for 1.6/2.0/2.1 Copyright (C) 2001-2019 Vinay Sajip. All Rights Reserved. Copyright (c) 2008 by Christian Heimes Copyright 2001-2019 by Vinay Sajip. All Rights Reserved. Copyright (C) 2005 Martin v. Löwis Licensed to PSF under a contributor agreement. ppc-darwin.S - Copyright (c) 2000 John Hornkvist Copyright (c) 2000, BeOpen.com. Copyright (C) 2001-2010 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (C) 2001-2007 Python Software Foundation Author: Anthony Baxter Contact: email-sig@python.org Copyright (c) 2004 by Fredrik Lundh Copyright Disney Enterprises, Inc. All Rights Reserved. ffi.c - Copyright (c) 1996, 1998, 1999, 2001 Red Hat, Inc. Copyright (C) 2003-2013 Python Software Foundation Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved. Copyright (c) . prepcif.c - Copyright (c) 1996, 1998 Red Hat, Inc. self.assertEqual(list(c), list(range(2,2000))) Copyright (C) 2002-2007 Python Software Foundation Author: Ben Gertzfield, Barry Warsaw Contact: email-sig@python.org -- Copyright (c) IBM Corporation, 2000, 2008. All rights reserved. -- Copyright (C) 2012 Christian Heimes (christian@python.org) fficommon.h - Copyright (c) 1996 Red Hat, Inc. Copyright (C) 2012-2016 Christian Heimes (christian@python.org) Copyright (C) 2004-2005 Gerhard Häring (c) 2002 Python Software Foundation. All Rights Reserved. Copyright (c) 2004 by Secret Labs AB, http://www.pythonware.com Copyright (c) 2004, Outercurve Foundation. Copyright (c) 2006-2008, R Oudkerk Licensed to PSF under a Contributor Agreement. .. Copyright 1995 Virginia Polytechnic Institute and State University and Fred L. Drake, Jr. This copyright notice must be distributed on all copies, but this document otherwise may be distributed as part of the Python distribution. No fee may be charged for this document in any representation, either on paper or electronically. This restriction does not affect other elements in a distributed package in any way. Copyright 2012-2013 by Larry Hastings. Copyright (C) 2002-2006 Python Software Foundation Contact: email-sig@python.org email package unit tests for (optional) Asian codecs Copyright (c) 2002 Peter O'Gorman Copyright 2007 Google Inc. Copyright (c) 1999 by Fredrik Lundh Copyright (C) 2001-2010 Python Software Foundation Contact: email-sig@python.org email package unit tests Copyright (c) 2003-2004 by Fredrik Lundh. All rights reserved. Copyright (c) 1991-1999 Unicode, Inc. All Rights reserved. Copyright (c) 2000-2017 Expat development team Licensed under the MIT license: Copyright (c) 1997-2000 Thai Open Source Software Center Ltd Copyright (c) 1998 The Open Group Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved.\ %version%, (c) 2001-2016 Python Software Foundation. Copyright (c) 2000-2010, eGenix.com Software GmbH; mailto:info@egenix.com Copyright (C) 2005 Gerhard Häring copyright = '2001-%s, Python Software Foundation' % time.strftime('%Y') Copyright (c) 1996-2008 Red Hat, Inc and others. Copyright (C) 2005 Martin v. 
Löwis Licensed to PSF under a Contributor Agreement. (c) 2001-2016 Python Software Foundation. %VERSION%, (c) 2001-2019 Python Software Foundation. Copyright (C) 1997, 2002, 2003, 2007, 2008 Martin von Loewis Copyright (c) 2013 Marek Majkowski Copyright (c) 1999-2008 by Fredrik Lundh. All rights reserved. Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, All rights reserved. Copyright (c) 1998, 1999, 2000 Thai Open Source Software Center Ltd and Clark Cooper 2001-04-15 fl export copyright as Python attribute, not global 2001-04-28 fl added copy methods (work in progress) Copyright (C) 2002, 2003 Python Software Foundation. Copyright (c) 2004, 2005, 2006 Python Software Foundation. dnl Copyright © 2012-2015 Dan Nicholson Copyright (c) 1999-2009 by Secret Labs AB. All rights reserved. Copyright (c) 2003-2010 Python Software Foundation This module is free software, and you may redistribute it and/or modify it under the same terms as Python itself, so long as this copyright message and disclaimer are retained in their original form. Copyright (c) 1991-1995 Stichting Mathematisch Centrum. All rights reserved. 2001-10-18 fl fixed group reset issue (from Matthew Mueller) Copyright (C) 1992-1996, 1998-2012 Free Software Foundation, Inc. -- Copyright (c) IBM Corporation, 1981, 2008. All rights reserved. -- Copyright (C) 2000 Bastian Kleineidam ppc-darwinclosure.S - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. based on ppcclosure.S Copyright (c) 2001 John Beniton Portions copyright 1991-1995 by Stichting Mathematisch Centrum, Amsterdam, The Netherlands. Copying is permitted under the terms associated with the main Python distribution, with the additional restriction that this additional notice be included and maintained on all distributed copies.

A. HISTORY OF THE SOFTWARE
==========================

Python was created in the early 1990s by Guido van Rossum at Stichting Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands as a successor of a language called ABC. Guido remains Python's principal author, although it includes many contributions from others.

In 1995, Guido continued his work on Python at the Corporation for National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) in Reston, Virginia where he released several versions of the software.

In May 2000, Guido and the Python core development team moved to BeOpen.com to form the BeOpen PythonLabs team. In October of the same year, the PythonLabs team moved to Digital Creations, which became Zope Corporation. In 2001, the Python Software Foundation (PSF, see https://www.python.org/psf/) was formed, a non-profit organization created specifically to own Python-related Intellectual Property. Zope Corporation was a sponsoring member of the PSF.

All Python releases are Open Source (see http://www.opensource.org for the Open Source Definition). Historically, most, but not all, Python releases have also been GPL-compatible; the table below summarizes the various releases.

    Release         Derived     Year        Owner       GPL-
                    from                                compatible? (1)

    0.9.0 thru 1.2              1991-1995   CWI         yes
    1.3 thru 1.5.2  1.2         1995-1999   CNRI        yes
    1.6             1.5.2       2000        CNRI        no
    2.0             1.6         2000        BeOpen.com  no
    1.6.1           1.6         2001        CNRI        yes (2)
    2.1             2.0+1.6.1   2001        PSF         no
    2.0.1           2.0+1.6.1   2001        PSF         yes
    2.1.1           2.1+2.0.1   2001        PSF         yes
    2.1.2           2.1.1       2002        PSF         yes
    2.1.3           2.1.2       2002        PSF         yes
    2.2 and above   2.1.1       2001-now    PSF         yes

Footnotes:

(1) GPL-compatible doesn't mean that we're distributing Python under the GPL.
All Python licenses, unlike the GPL, let you distribute a modified version without making your changes open source. The GPL-compatible licenses make it possible to combine Python with other software that is released under the GPL; the others don't. (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, because its license has a choice of law clause. According to CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 is "not incompatible" with the GPL. Thanks to the many outside volunteers who have worked under Guido's direction to make these releases possible. B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON =============================================================== PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 -------------------------------------------- 1. This LICENSE AGREEMENT is between the Python Software Foundation ("PSF"), and the Individual or Organization ("Licensee") accessing and otherwise using this software ("Python") in source or binary form and its associated documentation. 2. Subject to the terms and conditions of this License Agreement, PSF hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use Python alone or in any derivative version, provided, however, that PSF's License Agreement and PSF's notice of copyright, i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019 Python Software Foundation; All Rights Reserved" are retained in Python alone or in any derivative version prepared by Licensee. 3. In the event Licensee prepares a derivative work that is based on or incorporates Python or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to Python. 4. PSF is making Python available to Licensee on an "AS IS" basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 6. This License Agreement will automatically terminate upon a material breach of its terms and conditions. 7. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between PSF and Licensee. This License Agreement does not grant permission to use PSF trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party. 8. By copying, installing or otherwise using Python, Licensee agrees to be bound by the terms and conditions of this License Agreement. BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 ------------------------------------------- BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 1. 
This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the Individual or Organization ("Licensee") accessing and otherwise using this software in source or binary form and its associated documentation ("the Software"). 2. Subject to the terms and conditions of this BeOpen Python License Agreement, BeOpen hereby grants Licensee a non-exclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use the Software alone or in any derivative version, provided, however, that the BeOpen Python License is retained in the Software, alone or in any derivative version prepared by Licensee. 3. BeOpen is making the Software available to Licensee on an "AS IS" basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 5. This License Agreement will automatically terminate upon a material breach of its terms and conditions. 6. This License Agreement shall be governed by and interpreted in all respects by the law of the State of California, excluding conflict of law provisions. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between BeOpen and Licensee. This License Agreement does not grant permission to use BeOpen trademarks or trade names in a trademark sense to endorse or promote products or services of Licensee, or any third party. As an exception, the "BeOpen Python" logos available at http://www.pythonlabs.com/logos.html may be used according to the permissions granted on that web page. 7. By copying, installing or otherwise using the software, Licensee agrees to be bound by the terms and conditions of this License Agreement. CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 --------------------------------------- 1. This LICENSE AGREEMENT is between the Corporation for National Research Initiatives, having an office at 1895 Preston White Drive, Reston, VA 20191 ("CNRI"), and the Individual or Organization ("Licensee") accessing and otherwise using Python 1.6.1 software in source or binary form and its associated documentation. 2. Subject to the terms and conditions of this License Agreement, CNRI hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use Python 1.6.1 alone or in any derivative version, provided, however, that CNRI's License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 is made available subject to the terms and conditions in CNRI's License Agreement. 
This Agreement together with Python 1.6.1 may be located on the Internet using the following unique, persistent identifier (known as a handle): 1895.22/1013. This Agreement may also be obtained from a proxy server on the Internet using the following URL: http://hdl.handle.net/1895.22/1013". 3. In the event Licensee prepares a derivative work that is based on or incorporates Python 1.6.1 or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to Python 1.6.1. 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 6. This License Agreement will automatically terminate upon a material breach of its terms and conditions. 7. This License Agreement shall be governed by the federal intellectual property law of the United States, including without limitation the federal copyright law, and, to the extent such U.S. federal law does not apply, by the law of the Commonwealth of Virginia, excluding Virginia's conflict of law provisions. Notwithstanding the foregoing, with regard to derivative works based on Python 1.6.1 that incorporate non-separable material that was previously distributed under the GNU General Public License (GPL), the law of the Commonwealth of Virginia shall govern this License Agreement only as to issues arising under or with respect to Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between CNRI and Licensee. This License Agreement does not grant permission to use CNRI trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party. 8. By clicking on the "ACCEPT" button where indicated, or by copying, installing or otherwise using Python 1.6.1, Licensee agrees to be bound by the terms and conditions of this License Agreement. ACCEPT CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 -------------------------------------------------- Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved. Permission to use, copy, modify, and distribute this software and its documentation for any purpose and without fee is hereby granted, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation, and that the name of Stichting Mathematisch Centrum or CWI not be used in advertising or publicity pertaining to distribution of the software without specific, written prior permission. 
STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

Software: Python 3.8.5
Copyright notice:
copyright, i.e., "Copyright © 2001-2018 Python Software Foundation; All Rights Reserved" are retained in Python 3.8 alone or in any derivative version prepared by Licensee. Copyright (c) 1999-2000 by Secret Labs AB Copyright (C) 2005-2007 Gregory P. Smith (greg@krypto.org) Copyright (c) 2003. Copyright (C) 2005-2010 Gregory P. Smith (greg@krypto.org) Copyright 1996,1997 by Oliver Andrich, Koblenz, Germany. Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006 Python Software Foundation. Copyright (c) 1995-2001 Corporation for National Research Initiatives. All rights reserved. Copyright 1994 by Lance Ellinghouse Cathedral City, California Republic, United States of America. Copyright (C) 2001 Python Software Foundation Barry Warsaw , 2000. Copyright (c) 2008-2012 Stefan Krah. All rights reserved. 2001-07-01 fl added BIGCHARSET support (from Martin von Loewis) ``'Copyright 1991-1995 Stichting Mathematisch Centrum, Amsterdam'`` Copyright (C) 2003 Python Software Foundation Copyright (C) 2001-2016 Vinay Sajip. All Rights Reserved. Copyright 1995-1997, Automatrix, Inc., all rights reserved. Copyright (c) 2002 MyCompanyName. All rights reserved. Copyright (c) 2004 by Peter Astrand Copyright (c) 1999-2002 by Fredrik Lundh. Copyright (c) 1991-1995 Stichting Mathematisch Centrum, Amsterdam.\n\ AIX ABI support (c) 2002 Free Software Foundation, Inc. ( Copyright (c) 2011 Stefan Krah. All rights reserved. ) 2013-02-04 mrab added fullmatch primitive 2003-10-17 gn implemented non recursive scheme 2003-04-18 mvl fully support 4-byte codes Copyright (c) 1999-2002 by Secret Labs AB. portions copyright 2001, Autonomous Zones Industries, Inc., all rights... Copyright © 2013 W3C® (MIT, ERCIM, Keio, Beihang), All Rights Reserved. Copyright (C) 1986 Gary S. Brown. You may use this program, or code or tables extracted from it, as desired without restriction. -- Copyright (c) IBM Corporation, 2003, 2008. All rights reserved. -- ; Copyright (c) 2008-2016 Stefan Krah. All rights reserved. Copyright (c) 2001-2020 Python Software Foundation. All rights reserved. Copyright 2008 Armin Ronacher. Copyright © 2000 BeOpen.com. All rights reserved. (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org) Copyright (c) 2005-2006 ActiveState Software Inc. Copyright (C) 1994 Steen Lumholt. Copyright (c) 1999 by Fredrik Lundh. libffi - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (c) 2006-2008 Alexander Chemeris Copyright (C) 2002 Lars Gustaebel Copyright (c) 1999-2003 Steve Purcell Darwin ABI support (c) 2001 John Hornkvist Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com) Copyright (c) 1999-2008 by Fredrik Lundh i.e., "Copyright © 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 -- Copyright (c) IBM Corporation, 2005, 2009.
All rights reserved. -- Copyright (c) 2001-2017 Expat maintainers Copyright (c) 2001-2012 Python Software Foundation. All Rights Reserved. Copyright (C) 2002-2006 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org (c) 2002 Gregory P. Ward. All Rights Reserved. Copyright (c) 2000 BeOpen.com.\n\ Copyright (C) 2001-2007 Python Software Foundation Author: Ben Gertzfield, Barry Warsaw Contact: email-sig@python.org Copyright (C) 2003-2004 Federico Di Gregorio 2001-05-14 fl fixes for 1.5.2 compatibility Copyright © 1995-2000 Corporation for National Research Initiatives. All rights reserved. Copyright (c) 2013 W3C(R) (MIT, ERCIM, Keio, Beihang), All Rights Reserved. Copyright (C) 1995, 1996, 1997, 1998, and 1999 WIDE Project. Copyright (C) 2001-2007 Python Software Foundation Author: Barry Warsaw, Thomas Wouters, Anthony Baxter Contact: email-sig@python.org Copyright (C) 2005-2010 Gregory P. Smith (greg@krypto.org) Copyright 2001-2017 by Vinay Sajip. All Rights Reserved. Copyright (c) 2000 Doug White, 2006 James Knight, 2007 Christian Heimes All rights reserved. Copyright (C) 1999-2001 Gregory P. Ward. Copyright (c) 1999-2002 by Fredrik Lundh + Copyright 2007 Python Software Foundation. Copyright (c) 1999-2002 by Secret Labs AB 2002-11-09 fl fixed empty sub/subn return type Copyright 2009 Gabriel A. Genellina Copyright (c) 2003-2009 by Fredrik Lundh. All rights reserved. Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved. (c) Copyright Guido van Rossum, 2000. Copyright (C) 1995, 1996, 1997, and 1998 WIDE Project. Copyright (C) 2011-2012 Vinay Sajip. Copyright 2006 Google, Inc. All Rights Reserved. (c) Copyright Marc-Andre Lemburg, 2005. Copyright (C) 1996-2014 Free Software Foundation, Inc. Copyright (C) YEAR ORGANIZATION FIRST AUTHOR , YEAR. Copyright (c) 1999-2000, Marc-Andre Lemburg; mailto:mal@lemburg.com Copyright (c) 1995-2000, Corporation for National Research Initiatives. Copyright (c) 1999 Toby Dickenson Copyright (C) 2001,2002 Python Software Foundation csv package unit tests Copyright (C) 2005, 2006 Martin von Löwis Licensed to PSF under a Contributor Agreement. Copyright (c) 1997 by Fredrik Lundh Copyright (c) 2002-2006 Python Software Foundation. All rights reserved. Copyright (c) 2002 Roger Sayle Copyright 1995-1996 by Fred L. Drake, Jr. and Virginia Polytechnic Institute and State University, Blacksburg, Virginia, USA. types.c - Copyright (c) 1996, 1998 Red Hat, Inc. Copyright 2000, Mojam Media, Inc., all rights reserved. Copyright (C) 1994 X Consortium copyright, i.e., "Copyright © 2001-2020 Python Software Foundation; All Rights Reserved" are retained in Python |release| alone or in any derivative version prepared by Licensee. Copyright (C) 2004-2006 Python Software Foundation Authors: Baxter, Wouters and Warsaw Contact: email-sig@python.org Copyright (C) 2002-2004 Python Software Foundation Copyright (c) 2002 Jorge Acereda & darwin.S - Copyright (c) 1996, 1998, 2001, 2002, 2003 Red Hat, Inc. Copyright © 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved. "Copyright 1995-1996 by Virginia Polytechnic Institute & State\n\ x86-ffitarget.h - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (C) 1995-2011 Jean-loup Gailly and Mark Adler Copyright 1991-1995, Stichting Mathematisch Centrum, all rights reserved. Copyright © 2001-2020 Python Software Foundation. All rights reserved. (c) 2001-2020 Python Software Foundation.
Copyright (C) 2001-2006 Python Software Foundation Author: Ben Gertzfield Contact: email-sig@python.org Copyright (c) 2005 Don Owens All rights reserved. darwin64.S - Copyright (c) 2006 Free Software Foundation, Inc. (c) Copyright 2005, Marc-Andre Lemburg (mal@lemburg.com). ffi.c - Copyright (c) 1998 Geoffrey Keating Copyright 2006 Georg Brandl. Copyright (C) 2005-2010 Gerhard Häring (c) 2013-2017 Christian Heimes Copyright 1992-2018 Free Software Foundation, Inc. Copyright (C) 2003-2013 Python Software Foundation Copyright (C) 2001-2017 Vinay Sajip. All Rights Reserved. Copyright (c) 1990-1995, Stichting Mathematisch Centrum. 2001-10-24 fl added finditer primitive (for 2.2 only) Copyright (C) 2001-2006 Python Software Foundation Author: Keith Dart Contact: email-sig@python.org Copyright (c) 1999-2009 by Fredrik Lundh. (c) 2000 Peter Bosch. All Rights Reserved. Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved. Copyright (C) 2012 Free Software Foundation, Inc. Copyright (C) 2002-2007 Python Software Foundation Author: Ben Gertzfield Contact: email-sig@python.org i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation; Copyright 1994 by Lance Ellinghouse, Cathedral City, California Republic, United States of America. Copyright (C) 2006 - 2010 Gregor Lingl email: glingl@aon.at Copyright (c) 2009,2010 Zmanda Inc. Copyright (c) 1998-2008 The OpenSSL Project. All rights reserved. Copyright 1996 by Sam Rushing Copyright (c) 1998-2000 Thai Open Source Software Center Ltd and Clark Cooper (c) Craig Reese, Joe Campbell and Jeff Poskanzer 1989 / Copyright (c) 1999, 2000, 2001 Steve Purcell This module is free software, and you may redistribute it and/or modify it under the same terms as Python itself, so long as this copyright message and disclaimer are retained in their original form. -- Copyright (c) IBM Corporation, 2005, 2008. All rights reserved. -- Copyright (C) 2001-2012 Python Software Foundation. All Rights Reserved. dnl Copyright © 2004 Scott James Remnant . Copyright (C) 2002-2007 Python Software Foundation Contact: email-sig@python.org Copyright (c) 1998-2001 by Secret Labs AB. All rights reserved. Copyright (c) 2009,2010 Dustin J. Mitchell Copyright (c) 2002 Bo Thorsen 2001-10-21 fl added sub/subn primitive Copyright 1992-1994, David Gottner Copyright (C) 2011-2013 Vinay Sajip. Copyright (c) 1991, 2000, 2001 by Lucent Technologies. Copyright (c) 2010 Python Software Foundation. All Rights Reserved. " SRE 2.2.2 Copyright (c) 1997-2002 by Secret Labs AB "; ; Copyright (c) 2004, Outercurve Foundation. Copyright (c) 2003-2005 by Peter Astrand Copyright (c) 1999 by Secret Labs AB libffi PyOBJC - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (c) 1999-2009 by Fredrik Lundh Copyright 2007 Google, Inc. All Rights Reserved. -- Copyright (c) IBM Corporation, 2004, 2008. All rights reserved. -- Copyright (C) 2001-2007 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (c) 2000 BeOpen.com. All rights reserved. %version%, (c) 2001-2020 Python Software Foundation. Copyright © 1991-1995 Stichting Mathematisch Centrum. All rights reserved.
Copyright (C) 2000 Luke Kenneth Casson Leighton Copyright (C) 2005-2007 Gerhard Häring 2001-12-07 fl fixed memory leak in sub/subn (Guido van Rossum) -- Copyright (c) IBM Corporation, 2001, 2008. All rights reserved. -- Virginia, USA. Portions copyright 1991-1995 by Stichting Mathematisch\n\ Copyright (c) 2004 Free Software Foundation, Inc. so portions are Copyright (C) 2001,2002 Python Software Foundation, and were written by Barry Warsaw. Copyright (C) 2004-2010 Gerhard Häring Copyright (c) 2004 Python Software Foundation. (c) Copyright 2000 Guido van Rossum. Copyright 2007 Georg Brandl. Copyright (c) 1999 by Secret Labs AB. Copyright (c) 2002 Unicode, Inc. All Rights reserved. Copyright 2009 Brian Quinlan. All Rights Reserved. Copyright (c) 2008-2009, Google Inc. Copyright (c) 2001-2006 Twisted Matrix Laboratories. (c) Copyright CNRI, All Rights Reserved. NO WARRANTY. License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 is made available subject to the terms and conditions in CNRI's License Agreement. This Agreement together with Python 1.6.1 may be located on the Internet using the following unique, persistent identifier (known as a handle): 1895.22/1013. This Agreement may also be obtained from a proxy server on the Internet Copyright (c) Corporation for National Research Initiatives. Copyright (c) 2008-2016 Stefan Krah. All rights reserved. Copyright 2001-2016 by Vinay Sajip. All Rights Reserved. ppc-ffitarget.h - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (c) 2001-2006 Gregory P. Ward. All rights reserved. Copyright (c) 1997-2001 by Secret Labs AB. All rights reserved. Copyright 2000 by Timothy O'Malley Copyright (C) 2007-2012 Michael Foord & the mock team E-mail: fuzzyman AT voidspace DOT org DOT uk Copyright (C) 2011-2014 Vinay Sajip. ppc-darwin.h - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. ppc64-darwinclosure.S - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. based on ppcclosure.S x86-ffi64.c - Copyright (c) 2002 Bo Thorsen Copyright (c) 2002 Ranjit Mathew Copyright (C) 2001-2006 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (c) 2004, Outercurve Foundation. Copyright (c) 1995-2001 Corporation for National Research Initiatives.\n\ Copyright 1999, Bioreason, Inc., all rights reserved. 2001-10-20 fl added split primitive; re-enable unicode for 1.6/2.0/2.1 Copyright (C) 2001-2019 Vinay Sajip. All Rights Reserved. Copyright (c) 2008 by Christian Heimes Copyright 2001-2019 by Vinay Sajip. All Rights Reserved. Copyright (C) 2005 Martin v. Löwis Licensed to PSF under a contributor agreement. ppc-darwin.S - Copyright (c) 2000 John Hornkvist Copyright (c) 2000, BeOpen.com. Copyright (C) 2001-2010 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (C) 2001-2007 Python Software Foundation Author: Anthony Baxter Contact: email-sig@python.org Copyright (c) 2004 by Fredrik Lundh Copyright Disney Enterprises, Inc. All Rights Reserved. ffi.c - Copyright (c) 1996, 1998, 1999, 2001 Red Hat, Inc. Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved.
prepcif.c - Copyright (c) 1996, 1998 Red Hat, Inc. Copyright (C) 2002-2007 Python Software Foundation Author: Ben Gertzfield, Barry Warsaw Contact: email-sig@python.org -- Copyright (c) IBM Corporation, 2000, 2008. All rights reserved. -- Copyright (C) 2012 Christian Heimes (christian@python.org) fficommon.h - Copyright (c) 1996 Red Hat, Inc. Copyright (C) 2012-2016 Christian Heimes (christian@python.org) Copyright (C) 2004-2005 Gerhard Häring (c) 2002 Python Software Foundation. All Rights Reserved. Copyright (c) 2004 by Secret Labs AB, http://www.pythonware.com Copyright (c) 2004, Outercurve Foundation. Copyright (c) 2006-2008, R Oudkerk Licensed to PSF under a Contributor Agreement. .. Copyright 1995 Virginia Polytechnic Institute and State University and Fred L. Drake, Jr. This copyright notice must be distributed on all copies, but this document otherwise may be distributed as part of the Python distribution. No fee may be charged for this document in any representation, either on paper or electronically. This restriction does not affect other elements in a distributed package in any way. Copyright 2012-2013 by Larry Hastings. Copyright (C) 2002-2006 Python Software Foundation Contact: email-sig@python.org email package unit tests for (optional) Asian codecs Copyright (c) 2002 Peter O'Gorman Copyright 2007 Google Inc. Copyright (c) 1999 by Fredrik Lundh Copyright (C) 2001-2010 Python Software Foundation Contact: email-sig@python.org email package unit tests Copyright (c) 2003-2004 by Fredrik Lundh. All rights reserved. Copyright (c) 1991-1999 Unicode, Inc. All Rights reserved. Copyright (c) 2000-2017 Expat development team Licensed under the MIT license: Copyright (c) 1997-2000 Thai Open Source Software Center Ltd Copyright (c) 2001-2020 Python Software Foundation.\n\ Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved.\ Copyright (c) 1998 The Open Group Copyright (c) 2000-2010, eGenix.com Software GmbH; mailto:info@egenix.com Copyright (C) 2005 Gerhard Häring copyright = '2001-%s, Python Software Foundation' % time.strftime('%Y') Copyright (c) 1996-2008 Red Hat, Inc and others. Copyright (C) 2005 Martin v. Löwis Licensed to PSF under a Contributor Agreement. Copyright (C) 1997, 2002, 2003, 2007, 2008 Martin von Loewis %VERSION%, (c) 2001-2019 Python Software Foundation. Copyright (c) 2013 Marek Majkowski Copyright (c) 2008 Daniel Amelang Copyright (c) 1999-2008 by Fredrik Lundh. All rights reserved. Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, All rights reserved. Copyright (c) 1998, 1999, 2000 Thai Open Source Software Center Ltd and Clark Cooper 2001-04-15 fl export copyright as Python attribute, not global 2001-04-28 fl added copy methods (work in progress) Copyright (C) 2002, 2003 Python Software Foundation. Copyright (c) 2004, 2005, 2006 Python Software Foundation. dnl Copyright © 2012-2015 Dan Nicholson Copyright (c) 1999-2009 by Secret Labs AB. All rights reserved. Copyright (c) 2003-2010 Python Software Foundation This module is free software, and you may redistribute it and/or modify it under the same terms as Python itself, so long as this copyright message and disclaimer are retained in their original form. Copyright (c) 1991-1995 Stichting Mathematisch Centrum. All rights reserved. 2001-10-18 fl fixed group reset issue (from Matthew Mueller) Copyright (C) 1992-1996, 1998-2012 Free Software Foundation, Inc. -- Copyright (c) IBM Corporation, 1981, 2008.
All rights reserved. -- Copyright (C) 2000 Bastian Kleineidam ppc-darwinclosure.S - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. based on ppcclosure.S Portions copyright 1991-1995 by Stichting Mathematisch Centrum, Amsterdam, The Netherlands. Copying is permitted under the terms associated with the main Python distribution, with the additional restriction that this additional notice be included and maintained on all distributed copies. A. HISTORY OF THE SOFTWARE ========================== Python was created in the early 1990s by Guido van Rossum at Stichting Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands as a successor of a language called ABC. Guido remains Python's principal author, although it includes many contributions from others. In 1995, Guido continued his work on Python at the Corporation for National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) in Reston, Virginia where he released several versions of the software. In May 2000, Guido and the Python core development team moved to BeOpen.com to form the BeOpen PythonLabs team. In October of the same year, the PythonLabs team moved to Digital Creations, which became Zope Corporation. In 2001, the Python Software Foundation (PSF, see https://www.python.org/psf/) was formed, a non-profit organization created specifically to own Python-related Intellectual Property. Zope Corporation was a sponsoring member of the PSF. All Python releases are Open Source (see http://www.opensource.org for the Open Source Definition). Historically, most, but not all, Python releases have also been GPL-compatible; the table below summarizes the various releases.

    Release         Derived from  Year       Owner       GPL-compatible? (1)

    0.9.0 thru 1.2                1991-1995  CWI         yes
    1.3 thru 1.5.2  1.2           1995-1999  CNRI        yes
    1.6             1.5.2         2000       CNRI        no
    2.0             1.6           2000       BeOpen.com  no
    1.6.1           1.6           2001       CNRI        yes (2)
    2.1             2.0+1.6.1     2001       PSF         no
    2.0.1           2.0+1.6.1     2001       PSF         yes
    2.1.1           2.1+2.0.1     2001       PSF         yes
    2.1.2           2.1.1         2002       PSF         yes
    2.1.3           2.1.2         2002       PSF         yes
    2.2 and above   2.1.1         2001-now   PSF         yes

Footnotes: (1) GPL-compatible doesn't mean that we're distributing Python under the GPL. All Python licenses, unlike the GPL, let you distribute a modified version without making your changes open source. The GPL-compatible licenses make it possible to combine Python with other software that is released under the GPL; the others don't. (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, because its license has a choice of law clause. According to CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 is "not incompatible" with the GPL. Thanks to the many outside volunteers who have worked under Guido's direction to make these releases possible. B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON =============================================================== PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 -------------------------------------------- 1. This LICENSE AGREEMENT is between the Python Software Foundation ("PSF"), and the Individual or Organization ("Licensee") accessing and otherwise using this software ("Python") in source or binary form and its associated documentation. 2.
Subject to the terms and conditions of this License Agreement, PSF hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use Python alone or in any derivative version, provided, however, that PSF's License Agreement and PSF's notice of copyright, i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation; All Rights Reserved" are retained in Python alone or in any derivative version prepared by Licensee. 3. In the event Licensee prepares a derivative work that is based on or incorporates Python or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to Python. 4. PSF is making Python available to Licensee on an "AS IS" basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 6. This License Agreement will automatically terminate upon a material breach of its terms and conditions. 7. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between PSF and Licensee. This License Agreement does not grant permission to use PSF trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party. 8. By copying, installing or otherwise using Python, Licensee agrees to be bound by the terms and conditions of this License Agreement. BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 ------------------------------------------- BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the Individual or Organization ("Licensee") accessing and otherwise using this software in source or binary form and its associated documentation ("the Software"). 2. Subject to the terms and conditions of this BeOpen Python License Agreement, BeOpen hereby grants Licensee a non-exclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use the Software alone or in any derivative version, provided, however, that the BeOpen Python License is retained in the Software, alone or in any derivative version prepared by Licensee. 3. BeOpen is making the Software available to Licensee on an "AS IS" basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 4. 
BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 5. This License Agreement will automatically terminate upon a material breach of its terms and conditions. 6. This License Agreement shall be governed by and interpreted in all respects by the law of the State of California, excluding conflict of law provisions. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between BeOpen and Licensee. This License Agreement does not grant permission to use BeOpen trademarks or trade names in a trademark sense to endorse or promote products or services of Licensee, or any third party. As an exception, the "BeOpen Python" logos available at http://www.pythonlabs.com/logos.html may be used according to the permissions granted on that web page. 7. By copying, installing or otherwise using the software, Licensee agrees to be bound by the terms and conditions of this License Agreement. CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 --------------------------------------- 1. This LICENSE AGREEMENT is between the Corporation for National Research Initiatives, having an office at 1895 Preston White Drive, Reston, VA 20191 ("CNRI"), and the Individual or Organization ("Licensee") accessing and otherwise using Python 1.6.1 software in source or binary form and its associated documentation. 2. Subject to the terms and conditions of this License Agreement, CNRI hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use Python 1.6.1 alone or in any derivative version, provided, however, that CNRI's License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 is made available subject to the terms and conditions in CNRI's License Agreement. This Agreement together with Python 1.6.1 may be located on the Internet using the following unique, persistent identifier (known as a handle): 1895.22/1013. This Agreement may also be obtained from a proxy server on the Internet using the following URL: http://hdl.handle.net/1895.22/1013". 3. In the event Licensee prepares a derivative work that is based on or incorporates Python 1.6.1 or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to Python 1.6.1. 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 5. 
CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 6. This License Agreement will automatically terminate upon a material breach of its terms and conditions. 7. This License Agreement shall be governed by the federal intellectual property law of the United States, including without limitation the federal copyright law, and, to the extent such U.S. federal law does not apply, by the law of the Commonwealth of Virginia, excluding Virginia's conflict of law provisions. Notwithstanding the foregoing, with regard to derivative works based on Python 1.6.1 that incorporate non-separable material that was previously distributed under the GNU General Public License (GPL), the law of the Commonwealth of Virginia shall govern this License Agreement only as to issues arising under or with respect to Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between CNRI and Licensee. This License Agreement does not grant permission to use CNRI trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party. 8. By clicking on the "ACCEPT" button where indicated, or by copying, installing or otherwise using Python 1.6.1, Licensee agrees to be bound by the terms and conditions of this License Agreement. ACCEPT CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 -------------------------------------------------- Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved. Permission to use, copy, modify, and distribute this software and its documentation for any purpose and without fee is hereby granted, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation, and that the name of Stichting Mathematisch Centrum or CWI not be used in advertising or publicity pertaining to distribution of the software without specific, written prior permission. STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

Software: Python 3.9.2
Copyright notice:
Copyright (c) 1999-2000 by Secret Labs AB Copyright (C) 2005-2007 Gregory P. Smith (greg@krypto.org) Copyright (c) 2003. Copyright (C) 2005-2010 Gregory P. Smith (greg@krypto.org) Copyright 1996,1997 by Oliver Andrich, Koblenz, Germany. Copyright (c) 2008-2020 Stefan Krah. All rights reserved. Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006 Python Software Foundation. Copyright (c) 1995-2001 Corporation for National Research Initiatives. All rights reserved. Copyright 1994 by Lance Ellinghouse Cathedral City, California Republic, United States of America. Copyright (C) 2001 Python Software Foundation Barry Warsaw , 2000. Copyright (c) 2008-2012 Stefan Krah. All rights reserved.
2001-07-01 fl added BIGCHARSET support (from Martin von Loewis) ``'Copyright 1991-1995 Stichting Mathematisch Centrum, Amsterdam'`` Copyright (C) 2003 Python Software Foundation Copyright (C) 2001-2016 Vinay Sajip. All Rights Reserved. Copyright 1995-1997, Automatrix, Inc., all rights reserved. Copyright (c) 2002 MyCompanyName. All rights reserved. Copyright (c) 2004 by Peter Astrand Copyright (c) 1999-2002 by Fredrik Lundh. Copyright (c) 1991-1995 Stichting Mathematisch Centrum, Amsterdam.\n\ AIX ABI support (c) 2002 Free Software Foundation, Inc. Copyright (C) 1996-2020 Free Software Foundation, Inc. 2013-02-04 mrab added fullmatch primitive 2003-10-17 gn implemented non recursive scheme 2003-04-18 mvl fully support 4-byte codes Copyright (c) 1999-2002 by Secret Labs AB. portions copyright 2001, Autonomous Zones Industries, Inc., all rights... Copyright © 2013 W3C® (MIT, ERCIM, Keio, Beihang), All Rights Reserved. Copyright (C) 1986 Gary S. Brown. You may use this program, or code or tables extracted from it, as desired without restriction. -- Copyright (c) IBM Corporation, 2003, 2008. All rights reserved. -- ; Copyright (c) 2008-2020 Stefan Krah. All rights reserved. Copyright 2008 Armin Ronacher. Copyright © 2000 BeOpen.com. All rights reserved. (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org) Copyright (c) 2005-2006 ActiveState Software Inc. Copyright (C) 1994 Steen Lumholt. Copyright (c) 1999 by Fredrik Lundh. libffi - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (C) 2002 Lars Gustaebel Copyright (c) 1999-2003 Steve Purcell Darwin ABI support (c) 2001 John Hornkvist Copyright (c) 2001-2021 Python Software Foundation.\n\ Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com) Copyright (c) 1999-2008 by Fredrik Lundh i.e., "Copyright © 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 -- Copyright (c) IBM Corporation, 2005, 2009. All rights reserved. -- Copyright (c) 2001-2017 Expat maintainers Copyright (c) 2001-2012 Python Software Foundation. All Rights Reserved. Copyright (C) 2002-2006 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org (c) 2002 Gregory P. Ward. All Rights Reserved. copyright, i.e., "Copyright © 2001-2021 Python Software Foundation; All Rights Reserved" are retained in Python |release| alone or in any derivative version prepared by Licensee. Copyright (c) 2000 BeOpen.com.\n\ Copyright (C) 2001-2007 Python Software Foundation Author: Ben Gertzfield, Barry Warsaw Contact: email-sig@python.org Copyright (C) 2003-2004 Federico Di Gregorio 2001-05-14 fl fixes for 1.5.2 compatibility Copyright (c) 2001-2021 Python Software Foundation. All rights reserved. Copyright © 1995-2000 Corporation for National Research Initiatives. All rights reserved. Copyright (c) 2013 W3C(R) (MIT, ERCIM, Keio, Beihang), All Rights Reserved. Copyright (C) 1995, 1996, 1997, 1998, and 1999 WIDE Project. Copyright (C) 2001-2007 Python Software Foundation Author: Barry Warsaw, Thomas Wouters, Anthony Baxter Contact: email-sig@python.org Copyright (C) 2005-2010 Gregory P. Smith (greg@krypto.org) Copyright (c) 2000 Doug White, 2006 James Knight, 2007 Christian Heimes All rights reserved. Copyright (C) 1999-2001 Gregory P. Ward. 
Copyright (c) 1999-2002 by Fredrik Lundh + Copyright 2007 Python Software Foundation. Copyright (c) 1999-2002 by Secret Labs AB 2002-11-09 fl fixed empty sub/subn return type Copyright (C) 2003-2013 Python Software Foundation Copyright 2009 Gabriel A. Genellina Copyright (c) 2003-2009 by Fredrik Lundh. All rights reserved. Copyright 2004-2005 Elemental Security, Inc. All Rights Reserved. (c) Copyright Guido van Rossum, 2000. Copyright (C) 1995, 1996, 1997, and 1998 WIDE Project. Copyright (C) 2011-2012 Vinay Sajip. Copyright 2006 Google, Inc. All Rights Reserved. (c) Copyright Marc-Andre Lemburg, 2005. Copyright (C) YEAR ORGANIZATION FIRST AUTHOR , YEAR. Copyright (c) 1999-2000, Marc-Andre Lemburg; mailto:mal@lemburg.com Copyright (c) 1995-2000, Corporation for National Research Initiatives. Copyright (C) 2001 I'O, All Rights Reserved. Copyright (c) 1999 Toby Dickenson Copyright (C) 2001,2002 Python Software Foundation csv package unit tests Copyright (C) 2005, 2006 Martin von Löwis Licensed to PSF under a Contributor Agreement. Copyright (c) 1997 by Fredrik Lundh Copyright (c) 2002-2006 Python Software Foundation. All rights reserved. Copyright (c) 2002 Roger Sayle Copyright 1995-1996 by Fred L. Drake, Jr. and Virginia Polytechnic Institute and State University, Blacksburg, Virginia, USA. types.c - Copyright (c) 1996, 1998 Red Hat, Inc. Copyright 2000, Mojam Media, Inc., all rights reserved. Copyright (C) 1994 X Consortium Copyright (C) 2002-2004 Python Software Foundation Copyright (C) 2004-2006 Python Software Foundation Authors: Baxter, Wouters and Warsaw Contact: email-sig@python.org Copyright (c) 2002 Jorge Acereda & darwin.S - Copyright (c) 1996, 1998, 2001, 2002, 2003 Red Hat, Inc. Copyright © 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved. "Copyright 1995-1996 by Virginia Polytechnic Institute & State\n\ x86-ffitarget.h - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (C) 1995-2011 Jean-loup Gailly and Mark Adler Copyright 1991-1995, Stichting Mathematisch Centrum, all rights reserved. (c) 2001-2020 Python Software Foundation. darwin64.S - Copyright (c) 2006 Free Software Foundation, Inc. Copyright (C) 2001-2006 Python Software Foundation Author: Ben Gertzfield Contact: email-sig@python.org Copyright (c) 2005 Don Owens All rights reserved. (c) Copyright 2005, Marc-Andre Lemburg (mal@lemburg.com). ffi.c - Copyright (c) 1998 Geoffrey Keating Copyright 2006 Georg Brandl. Copyright (C) 2005-2010 Gerhard Häring (c) 2013-2017 Christian Heimes Copyright 1992-2018 Free Software Foundation, Inc. Copyright (c) 1990-1995, Stichting Mathematisch Centrum. i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021 Python Software Foundation; 2001-10-24 fl added finditer primitive (for 2.2 only) Copyright (C) 2001-2006 Python Software Foundation Author: Keith Dart Contact: email-sig@python.org Copyright (c) 1999-2009 by Fredrik Lundh. (c) 2000 Peter Bosch. All Rights Reserved. %version%, (c) 2001-2021 Python Software Foundation. Copyright (C) 2001 earthian@tama.or.jp, All Rights Reserved. Copyright (C) 1993 by Sun Microsystems, Inc.
All rights reserved. Copyright (C) 2012 Free Software Foundation, Inc. Copyright (C) 2002-2007 Python Software Foundation Author: Ben Gertzfield Contact: email-sig@python.org Copyright 1994 by Lance Ellinghouse, Cathedral City, California Republic, United States of America. Copyright (C) 2006 - 2010 Gregor Lingl email: glingl@aon.at Copyright © 2001-2021 Python Software Foundation. All rights reserved. Copyright (c) 2009,2010 Zmanda Inc. Copyright (c) 1998-2008 The OpenSSL Project. All rights reserved. Copyright 1996 by Sam Rushing Copyright (c) 1998-2000 Thai Open Source Software Center Ltd and Clark Cooper copyright, i.e., "Copyright © 2001-2018 Python Software Foundation; All Rights Reserved" are retained in Python 3.9 alone or in any derivative version prepared by Licensee. (c) Craig Reese, Joe Campbell and Jeff Poskanzer 1989 / Copyright (c) 1999, 2000, 2001 Steve Purcell This module is free software, and you may redistribute it and/or modify it under the same terms as Python itself, so long as this copyright message and disclaimer are retained in their original form. -- Copyright (c) IBM Corporation, 2005, 2008. All rights reserved. -- Copyright (C) 2001-2012 Python Software Foundation. All Rights Reserved. dnl Copyright © 2004 Scott James Remnant . Copyright (C) 2002-2007 Python Software Foundation Contact: email-sig@python.org Copyright (c) 1998-2001 by Secret Labs AB. All rights reserved. Copyright (c) 2009,2010 Dustin J. Mitchell Copyright (c) 2002 Bo Thorsen 2001-10-21 fl added sub/subn primitive Copyright 1992-1994, David Gottner Copyright (C) 2011-2013 Vinay Sajip. Copyright (c) 1991, 2000, 2001 by Lucent Technologies. Copyright (c) 2010 Python Software Foundation. All Rights Reserved. " SRE 2.2.2 Copyright (c) 1997-2002 by Secret Labs AB "; ; Copyright (c) 2004, Outercurve Foundation. Copyright (c) 2003-2005 by Peter Astrand Copyright (c) 1999 by Secret Labs AB libffi PyOBJC - Copyright (c) 1996-2003 Red Hat, Inc. Copyright (c) 1999-2009 by Fredrik Lundh Copyright 2007 Google, Inc. All Rights Reserved. -- Copyright (c) IBM Corporation, 2004, 2008. All rights reserved. -- Copyright (C) 2001-2007 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (c) 2000 BeOpen.com. All rights reserved. %version%, (c) 2001-2020 Python Software Foundation. Copyright © 1991-1995 Stichting Mathematisch Centrum. All rights reserved. Copyright (C) 2000 Luke Kenneth Casson Leighton Copyright (C) 2005-2007 Gerhard Häring 2001-12-07 fl fixed memory leak in sub/subn (Guido van Rossum) -- Copyright (c) IBM Corporation, 2001, 2008. All rights reserved. -- Virginia, USA. Portions copyright 1991-1995 by Stichting Mathematisch\n\ Copyright (c) 2004 Free Software Foundation, Inc. so portions are Copyright (C) 2001,2002 Python Software Foundation, and were written by Barry Warsaw. Copyright (C) 2004-2010 Gerhard Häring Copyright (c) 2004 Python Software Foundation. (c) Copyright 2000 Guido van Rossum. Copyright 2007 Georg Brandl. Copyright (c) 1999 by Secret Labs AB. Copyright (c) 2002 Unicode, Inc. All Rights reserved. Copyright 2009 Brian Quinlan. All Rights Reserved. Copyright (c) 2008-2009, Google Inc. Copyright (c) 2001-2006 Twisted Matrix Laboratories. (c) Copyright CNRI, All Rights Reserved. NO WARRANTY.
License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 is made available subject to the terms and conditions in CNRI's License Agreement. This Agreement together with Python 1.6.1 may be located on the Internet using the following unique, persistent identifier (known as a handle): 1895.22/1013. This Agreement may also be obtained from a proxy server on the Internet Copyright (c) Corporation for National Research Initiatives. Copyright 2001-2016 by Vinay Sajip. All Rights Reserved. ppc-ffitarget.h - Copyright (c) 1996-2003 Red Hat, Inc. ppc-darwin.h - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. Copyright (c) 2001-2006 Gregory P. Ward. All rights reserved. Copyright (c) 1997-2001 by Secret Labs AB. All rights reserved. Copyright 2000 by Timothy O'Malley Copyright (C) 2007-2012 Michael Foord & the mock team E-mail: fuzzyman AT voidspace DOT org DOT uk Copyright (C) 2011-2014 Vinay Sajip. x86-ffi64.c - Copyright (c) 2002 Bo Thorsen ppc64-darwinclosure.S - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. based on ppcclosure.S Copyright (c) 2002 Ranjit Mathew Copyright (C) 2001-2006 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (c) 2004, Outercurve Foundation. Copyright (c) 1995-2001 Corporation for National Research Initiatives.\n\ Copyright 1999, Bioreason, Inc., all rights reserved. 2001-10-20 fl added split primitive; re-enable unicode for 1.6/2.0/2.1 Copyright (C) 2001-2019 Vinay Sajip. All Rights Reserved. Copyright (c) 2008 by Christian Heimes Copyright 2001-2019 by Vinay Sajip. All Rights Reserved. Copyright (C) 2005 Martin v. Löwis Licensed to PSF under a contributor agreement. ppc-darwin.S - Copyright (c) 2000 John Hornkvist Copyright (c) 2000, BeOpen.com. Copyright (C) 2001-2010 Python Software Foundation Author: Barry Warsaw Contact: email-sig@python.org Copyright (C) 2001-2007 Python Software Foundation Author: Anthony Baxter Contact: email-sig@python.org Copyright (c) 2004 by Fredrik Lundh Copyright Disney Enterprises, Inc. All Rights Reserved. ffi.c - Copyright (c) 1996, 1998, 1999, 2001 Red Hat, Inc. Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved. prepcif.c - Copyright (c) 1996, 1998 Red Hat, Inc. Copyright (C) 2002-2007 Python Software Foundation Author: Ben Gertzfield, Barry Warsaw Contact: email-sig@python.org -- Copyright (c) IBM Corporation, 2000, 2008. All rights reserved. -- Copyright (C) 2012 Christian Heimes (christian@python.org) fficommon.h - Copyright (c) 1996 Red Hat, Inc. Copyright (C) 2012-2016 Christian Heimes (christian@python.org) Copyright (C) 2004-2005 Gerhard Häring (c) 2002 Python Software Foundation. All Rights Reserved. Copyright (c) 2004 by Secret Labs AB, http://www.pythonware.com Copyright (c) 2004, Outercurve Foundation. Copyright (c) 2006-2008, R Oudkerk Licensed to PSF under a Contributor Agreement. .. Copyright 1995 Virginia Polytechnic Institute and State University and Fred L. Drake, Jr.
This copyright notice must be distributed on all copies, but this document otherwise may be distributed as part of the Python distribution. No fee may be charged for this document in any representation, either on paper or electronically. This restriction does not affect other elements in a distributed package in any way. Copyright 2012-2013 by Larry Hastings. Copyright (C) 2002-2006 Python Software Foundation Contact: email-sig@python.org email package unit tests for (optional) Asian codecs Copyright (c) 2002 Peter O'Gorman Copyright 2007 Google Inc. Copyright (c) 1999 by Fredrik Lundh Copyright (C) 2001-2010 Python Software Foundation Contact: email-sig@python.org email package unit tests Copyright (c) 2003-2004 by Fredrik Lundh. All rights reserved. Copyright (c) 1991-1999 Unicode, Inc. All Rights reserved. Copyright (c) 2000-2017 Expat development team Licensed under the MIT license: Copyright (c) 1997-2000 Thai Open Source Software Center Ltd Copyright (c) 1998 The Open Group Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved.\ Copyright (c) 2000-2010, eGenix.com Software GmbH; mailto:info@egenix.com Copyright (C) 2005 Gerhard Häring copyright = '2001-%s, Python Software Foundation' % time.strftime('%Y') Copyright (c) 1996-2008 Red Hat, Inc and others. Copyright (C) 2005 Martin v. Löwis Licensed to PSF under a Contributor Agreement. Copyright (C) 1997, 2002, 2003, 2007, 2008 Martin von Loewis %VERSION%, (c) 2001-2019 Python Software Foundation. ( Copyright (c) 2011-2020 Stefan Krah. All rights reserved. ) Copyright (c) 2013 Marek Majkowski Copyright (c) 2008 Daniel Amelang Copyright (c) 1999-2008 by Fredrik Lundh. All rights reserved. Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, All rights reserved. Copyright (c) 1998, 1999, 2000 Thai Open Source Software Center Ltd and Clark Cooper 2001-04-15 fl export copyright as Python attribute, not global 2001-04-28 fl added copy methods (work in progress) Copyright (C) 2002, 2003 Python Software Foundation. Copyright (c) 2004, 2005, 2006 Python Software Foundation. dnl Copyright © 2012-2015 Dan Nicholson Copyright (c) 1999-2009 by Secret Labs AB. All rights reserved. Copyright (c) 2003-2010 Python Software Foundation This module is free software, and you may redistribute it and/or modify it under the same terms as Python itself, so long as this copyright message and disclaimer are retained in their original form. Copyright (c) 1991-1995 Stichting Mathematisch Centrum. All rights reserved. 2001-10-18 fl fixed group reset issue (from Matthew Mueller) Copyright (C) 1992-1996, 1998-2012 Free Software Foundation, Inc. -- Copyright (c) IBM Corporation, 1981, 2008. All rights reserved. -- Copyright (C) 2000 Bastian Kleineidam ppc-darwinclosure.S - Copyright (c) 2002, 2003, 2004, Free Software Foundation, Inc. based on ppcclosure.S Portions copyright 1991-1995 by Stichting Mathematisch Centrum, Amsterdam, The Netherlands. Copying is permitted under the terms associated with the main Python distribution, with the additional restriction that this additional notice be included and maintained on all distributed copies. A. HISTORY OF THE SOFTWARE ========================== Python was created in the early 1990s by Guido van Rossum at Stichting Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands as a successor of a language called ABC. Guido remains Python's principal author, although it includes many contributions from others. 
In 1995, Guido continued his work on Python at the Corporation for National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) in Reston, Virginia where he released several versions of the software.

In May 2000, Guido and the Python core development team moved to BeOpen.com to form the BeOpen PythonLabs team. In October of the same year, the PythonLabs team moved to Digital Creations, which became Zope Corporation. In 2001, the Python Software Foundation (PSF, see https://www.python.org/psf/) was formed, a non-profit organization created specifically to own Python-related Intellectual Property. Zope Corporation was a sponsoring member of the PSF.

All Python releases are Open Source (see http://www.opensource.org for the Open Source Definition). Historically, most, but not all, Python releases have also been GPL-compatible; the table below summarizes the various releases.

    Release         Derived     Year        Owner       GPL-
                    from                                compatible? (1)

    0.9.0 thru 1.2              1991-1995   CWI         yes
    1.3 thru 1.5.2  1.2         1995-1999   CNRI        yes
    1.6             1.5.2       2000        CNRI        no
    2.0             1.6         2000        BeOpen.com  no
    1.6.1           1.6         2001        CNRI        yes (2)
    2.1             2.0+1.6.1   2001        PSF         no
    2.0.1           2.0+1.6.1   2001        PSF         yes
    2.1.1           2.1+2.0.1   2001        PSF         yes
    2.1.2           2.1.1       2002        PSF         yes
    2.1.3           2.1.2       2002        PSF         yes
    2.2 and above   2.1.1       2001-now    PSF         yes

Footnotes:

(1) GPL-compatible doesn't mean that we're distributing Python under the GPL. All Python licenses, unlike the GPL, let you distribute a modified version without making your changes open source. The GPL-compatible licenses make it possible to combine Python with other software that is released under the GPL; the others don't.

(2) According to Richard Stallman, 1.6.1 is not GPL-compatible, because its license has a choice of law clause. According to CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 is "not incompatible" with the GPL.

Thanks to the many outside volunteers who have worked under Guido's direction to make these releases possible.

B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON
===============================================================

Python software and documentation are licensed under the Python Software Foundation License Version 2.

Starting with Python 3.8.6, examples, recipes, and other code in the documentation are dual licensed under the PSF License Version 2 and the Zero-Clause BSD license.

Some software incorporated into Python is under different licenses. The licenses are listed with code falling under that license.

PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
--------------------------------------------

1. This LICENSE AGREEMENT is between the Python Software Foundation ("PSF"), and the Individual or Organization ("Licensee") accessing and otherwise using this software ("Python") in source or binary form and its associated documentation.

2. Subject to the terms and conditions of this License Agreement, PSF hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use Python alone or in any derivative version, provided, however, that PSF's License Agreement and PSF's notice of copyright, i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021 Python Software Foundation; All Rights Reserved" are retained in Python alone or in any derivative version prepared by Licensee.
3. In the event Licensee prepares a derivative work that is based on or incorporates Python or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to Python.

4. PSF is making Python available to Licensee on an "AS IS" basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.

5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.

6. This License Agreement will automatically terminate upon a material breach of its terms and conditions.

7. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between PSF and Licensee. This License Agreement does not grant permission to use PSF trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party.

8. By copying, installing or otherwise using Python, Licensee agrees to be bound by the terms and conditions of this License Agreement.

BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0
-------------------------------------------

BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1

1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the Individual or Organization ("Licensee") accessing and otherwise using this software in source or binary form and its associated documentation ("the Software").

2. Subject to the terms and conditions of this BeOpen Python License Agreement, BeOpen hereby grants Licensee a non-exclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use the Software alone or in any derivative version, provided, however, that the BeOpen Python License is retained in the Software, alone or in any derivative version prepared by Licensee.

3. BeOpen is making the Software available to Licensee on an "AS IS" basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.

4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.

5. This License Agreement will automatically terminate upon a material breach of its terms and conditions.

6. This License Agreement shall be governed by and interpreted in all respects by the law of the State of California, excluding conflict of law provisions. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between BeOpen and Licensee.
This License Agreement does not grant permission to use BeOpen trademarks or trade names in a trademark sense to endorse or promote products or services of Licensee, or any third party. As an exception, the "BeOpen Python" logos available at http://www.pythonlabs.com/logos.html may be used according to the permissions granted on that web page.

7. By copying, installing or otherwise using the software, Licensee agrees to be bound by the terms and conditions of this License Agreement.

CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1
---------------------------------------

1. This LICENSE AGREEMENT is between the Corporation for National Research Initiatives, having an office at 1895 Preston White Drive, Reston, VA 20191 ("CNRI"), and the Individual or Organization ("Licensee") accessing and otherwise using Python 1.6.1 software in source or binary form and its associated documentation.

2. Subject to the terms and conditions of this License Agreement, CNRI hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use Python 1.6.1 alone or in any derivative version, provided, however, that CNRI's License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) 1995-2001 Corporation for National Research Initiatives; All Rights Reserved" are retained in Python 1.6.1 alone or in any derivative version prepared by Licensee. Alternately, in lieu of CNRI's License Agreement, Licensee may substitute the following text (omitting the quotes): "Python 1.6.1 is made available subject to the terms and conditions in CNRI's License Agreement. This Agreement together with Python 1.6.1 may be located on the Internet using the following unique, persistent identifier (known as a handle): 1895.22/1013. This Agreement may also be obtained from a proxy server on the Internet using the following URL: http://hdl.handle.net/1895.22/1013".

3. In the event Licensee prepares a derivative work that is based on or incorporates Python 1.6.1 or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to Python 1.6.1.

4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.

5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.

6. This License Agreement will automatically terminate upon a material breach of its terms and conditions.

7. This License Agreement shall be governed by the federal intellectual property law of the United States, including without limitation the federal copyright law, and, to the extent such U.S. federal law does not apply, by the law of the Commonwealth of Virginia, excluding Virginia's conflict of law provisions.
Notwithstanding the foregoing, with regard to derivative works based on Python 1.6.1 that incorporate non-separable material that was previously distributed under the GNU General Public License (GPL), the law of the Commonwealth of Virginia shall govern this License Agreement only as to issues arising under or with respect to Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between CNRI and Licensee. This License Agreement does not grant permission to use CNRI trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party.

8. By clicking on the "ACCEPT" button where indicated, or by copying, installing or otherwise using Python 1.6.1, Licensee agrees to be bound by the terms and conditions of this License Agreement.

ACCEPT

CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2
--------------------------------------------------

Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, The Netherlands. All rights reserved.

Permission to use, copy, modify, and distribute this software and its documentation for any purpose and without fee is hereby granted, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation, and that the name of Stichting Mathematisch Centrum or CWI not be used in advertising or publicity pertaining to distribution of the software without specific, written prior permission.

STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

ZERO-CLAUSE BSD LICENSE FOR CODE IN THE PYTHON DOCUMENTATION
------------------------------------------------------------

Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted.

THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

Software: zlib 1.2.11

Copyright notice:
Copyright (C) 1995-2011, 2016 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1998,1999,2000 by Jacques Nomssi Nzali. echo 'pragma comment(copyright, "Copyright (C) 1995-2017 Jean-Loup Gailly, Mark Adler. OS/400 version by P. Monnerat.")' >> os400.c makemodule OS400 os400.c LINK= No need to rebuild service program yet.
Copyright (C) 2007-2008 Even Rouault Copyright (C) 2003 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1995-2003 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 2009-2010 Mathias Svensson ( http:result42.com ) Copyright (C) 1995-2003 by Jean-loup Gailly. Copyright (C) 1998-2005 Gilles Vollant © Copyright Henrik Ravn 2004 Copyright (C) 2003, 2005, 2008, 2010, 2012 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Version 1.7 12 August 2012 Mark Adler / Copyright (C) 1995-1998 Jean-loup Gailly. Copyright (C) 1995-2003, 2010 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (c) 2004, 2005 Mark Adler. Copyright (C) 1995-2005, 2010 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1998 - 2010 Gilles Vollant, Even Rouault, Mathias Svensson -- Copyright (C) 2002-2004 Dmitriy Anisimkov -- " inflate9 1.2.11 Copyright 1995-2017 Mark Adler "; Copyright (C) 2004, 2008, 2012, 2016 Mark Adler, all rights reserved For conditions of distribution and use, see copyright notice in gzlog.h version 2.2, 14 Aug 2012 VALUE "LegalCopyright", "(C) 1995-2017 Jean-loup Gailly & Mark Adler\0" ; Copyright (C) 1995-2003 Mark Adler ; For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1995-2008 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 2003 by Cosmin Truta. [assembly: AssemblyCopyright("(c) 2004 by Henrik Ravn")] " unzip 1.01 Copyright 1998-2004 Gilles Vollant - http:www.winimage.com/zLibDll"; Copyright (C) 1995-2006, 2011, 2016 Jean-loup Gailly For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1995-2017 Jean-loup Gailly and Mark Adler .LP This software is provided 'as-is', without any express or implied warranty. In no event will the authors be held liable for any damages arising from the use of this software. Copyright (c) 1997 Christian Michelsen Research AS Advanced Computing Fantoftvegen 38, 5036 BERGEN, Norway Copyright (C) 2004, 2010 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h "gzappend 1.2 (11 Oct 2012) Copyright (C) 2003, 2012 Mark Adler\n" Copyright (C) 1995-2003 Jean-loup Gailly and Mark Adler. Copyright (C) 2002-2013 Mark Adler, all rights reserved version 2.3, 21 Jan 2013 Copyright (C) 1998-2010 Gilles Vollant (minizip) ( http:www.winimage.com/zLibDll/minizip.html ) Copyright (C) 2004, 2008, 2012 Mark Adler, all rights reserved version 2.2, 14 Aug 2012 fprintf(stderr, "Copyright (C) 2003-2010 Mark Adler\n"); " inflate 1.2.11 Copyright 1995-2017 Mark Adler "; ; Copyright (C) 1995-2010 Jean-loup Gailly, Brian Raiter and Gilles Vollant. Copyright (C) 1995-2006, 2010, 2011, 2016 Jean-loup Gailly For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1998, 2007 Brian Raiter VALUE "LegalCopyright", "(C) 1995-2017 Jean-loup Gailly & Mark Adler\0" Copyright (C) 2004, 2005, 2012 Mark Adler, all rights reserved version 1.2, 14 Aug 2012 Copyright (C) 1995-2017 Jean-loup Gailly, Mark Adler For conditions of distribution and use, see copyright notice in zlib.h const char zipcopyright[] =" zip 1.01 Copyright 1998-2004 Gilles Vollant - http:www.winimage.com/zLibDll"; Copyright (C) 1998 by Bob Dellaca. 
;;; Copyright (C) 1998 Brian Raiter Copyright (C) 1995-2017 Jean-loup Gailly and Mark Adler Copyright (C) 1995-2017 Jean-loup Gailly and Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1995-2017 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 2005, 2012 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Version 1.1 29 Sep 2012 Mark Adler / Copyright (C) 1995-2003, 2010, 2014, 2016 Jean-loup Gailly, Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1995-2017 Jean-loup Gailly detectdatatype() function provided freely by Cosmin Truta, 2006 For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1998 by Jacques Nomssi Nzali. MiniZip - Copyright (c) 1998-2010 - by Gilles Vollant - version 1.1 64 bits from Mathias Svensson Copyright (c) Henrik Ravn 2004 Copyright (C) 1995-2005, 2014, 2016 Jean-loup Gailly, Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 2004-2017 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 1995-2003 Jean-loup Gailly. Copyright (c) 1996 L. Peter Deutsch Copyright (C) 2003 Chris Anderson Copyright (C) 1995-2017 Jean-loup Gailly For conditions of distribution and use, see copyright notice in zlib.h " deflate 1.2.11 Copyright 1995-2017 Jean-loup Gailly and Mark Adler "; ; Copyright (C) 1995-1996 Jean-loup Gailly, Brian Raiter and Gilles Vollant. Copyright (C) 2003 Cosmin Truta. Copyright (C) 2003, 2012 Mark Adler, all rights reserved version 1.2, 11 Oct 2012 Copyright (c) 1996 L. Peter Deutsch and Jean-Loup Gailly -- Copyright (C) 2002-2004 Dmitriy Anisimkov -- Copyright (C) 1998 by Andreas R. Kleinert Copyright (C) 1995-2016 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 2011, 2016 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 2004, 2005, 2010, 2011, 2012, 2013, 2016 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h -- Copyright (C) 2002-2003 Dmitriy Anisimkov -- (C) 1995-2017 Jean-loup Gailly and Mark Adler Copyright (C) 2003, 2012, 2013 Mark Adler version 1.3, 24 Aug 2013 Copyright (C) 2002-2013 Mark Adler For conditions of distribution and use, see copyright notice in puff.h version 2.3, 21 Jan 2013 Copyright (C) 1995-2016 Jean-loup Gailly For conditions of distribution and use, see copyright notice in zlib.h Copyright (c) 2004, 2005 by Mark Adler
Last modified 11 December 2005
; Copyright (C) 2003 Chris Anderson Copyright (C) 1995-2016 Jean-loup Gailly, Mark Adler For conditions of distribution and use, see copyright notice in zlib.h { Copyright (c) 1997,99 Borland Corporation } Copyright (C) 1995-2006, 2010, 2011, 2012, 2016 Mark Adler For conditions of distribution and use, see copyright notice in zlib.h Copyright (C) 2003, 2012, 2013 Mark Adler For conditions of distribution and use, see copyright notice in blast.h version 1.3, 24 Aug 2013 Copyright (C) 2007, 2008, 2012 Mark Adler Version 1.4 18 August 2012 Mark Adler Copyright (c) 1990-2000 Info-ZIP. All rights reserved. Copyright (C) 1995-2017 Jean-loup Gailly and Mark Adler

This software is provided 'as-is', without any express or implied warranty. In no event will the authors be held liable for any damages arising from the use of this software.

Permission is granted to anyone to use this software for any purpose, including commercial applications, and to alter it and redistribute it freely, subject to the following restrictions:

1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required.
2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software.
3. This notice may not be removed or altered from any source distribution.

Software: openssl 1.1.0

Copyright notice:
Copyright 2000-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2005-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2016-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2012-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright (C) 1989, 1991 Free Software Foundation, Inc. Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com) Copyright (C) 2006, Network Resonance, Inc. Copyright (C) 2011, RTFM, Inc. Copyright (C) 2017 National Security Research Institute. All Rights Reserved. Copyright (c) 1995-1998 Eric A. Young, Tim J. Hudson Copyright (c) 1998-2019 The OpenSSL Project. Copyright (c) 1998-2021 The OpenSSL Project Copyright (c) 2002 The OpenTSA Project. Copyright (c) 2002, Oracle and/or its affiliates. All rights reserved Copyright (c) 2004 Kungliga Tekniska Högskolan (Royal Institute of Technology, Stockholm, Sweden). Copyright (c) 2004, 2018, Richard Levitte richard@levitte.org Copyright (c) 2004, EdelKey Project. All Rights Reserved. Copyright (c) 2004, Richard Levitte richard@levitte.org Copyright (c) 2007 KISA(Korea Information Security Agency). Copyright (c) 2008 Andy Polyakov appro@openssl.org Copyright (c) 2012, Intel Corporation. All Rights Reserved. Copyright (c) 2012-2014 Daniel J. Bernstein Copyright (c) 2012-2016 Jean-Philippe Aumasson Copyright (c) 2013-2014 Timo Teräs timo.teras@gmail.com Copyright (c) 2014, Intel Corporation. All Rights Reserved. Copyright (c) 2015 CloudFlare, Inc. Copyright (c) 2015, CloudFlare, Inc. Copyright (c) 2016 Viktor Dukhovni openssl-users@dukhovni.org. Copyright (c) 2017 National Security Research Institute. Copyright (c) 2017, Oracle and/or its affiliates. Copyright (c) 2018, Oracle and/or its affiliates. Copyright 1995-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 1995-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved.
Copyright 1995-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 1995-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 1995-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 1998-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 1998-2001 The OpenSSL Project Authors. All Rights Reserved. Copyright 1998-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 1998-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 1998-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 1998-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 1998-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 1998-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 1999-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 1999-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 1999-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 1999-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 1999-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 1999-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 1999-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2000-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2000-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2000-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2000-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2000-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2000-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2000-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2001-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2001-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2001-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2001-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2001-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2001-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2002-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2002-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2002-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2002-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2002-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2003-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2003-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2003-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2003-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2003-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2004-2014, Akamai Technologies. All Rights Reserved. Copyright 2004-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2004-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2004-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2004-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2004-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2004-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2005 Nokia. Copyright 2005-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2005-2017 The OpenSSL Project Authors. All Rights Reserved. 
Copyright 2005-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2005-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2005-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2005-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2006 NTT (Nippon Telegraph and Telephone Corporation) . Copyright 2006-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2006-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2006-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2006-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2006-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2006-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2007-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2007-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2007-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2008-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2008-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2008-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2008-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2008-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2009-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2009-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2009-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2009-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2010-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2010-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2010-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2010-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2011 Google Inc. Copyright 2011-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2011-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2011-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2011-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2011-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2011-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2012, Samuel Neves sneves@dei.uc.pt Copyright 2012-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2012-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2012-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2012-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2012-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2013 M. J. Dominus. Copyright 2013 Mark Jason Dominus Copyright 2013-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2013-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2013-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2013-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2013-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2014 Cryptography Research, Inc. Copyright 2014-2016 Cryptography Research, Inc. Copyright 2014-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2014-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2014-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2014-2019 The OpenSSL Project Authors. All Rights Reserved. 
Copyright 2014-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2015 Cryptography Research, Inc. Copyright 2015-2016 Cryptography Research, Inc. Copyright 2015-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2015-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2015-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2015-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2015-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2015-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2016 Cryptography Research, Inc. Copyright 2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2016 VMS Software, Inc. All Rights Reserved. Copyright 2016-2016 The OpenSSL Project Authors. All Rights Reserved. Copyright 2016-2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2016-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2016-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2016-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2016-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2017 BaishanCloud. Copyright 2017 Ribose Inc. All Rights Reserved. Copyright 2017 The OpenSSL Project Authors. All Rights Reserved. Copyright 2017 Ribose Inc.. All Rights Reserved. Copyright 2017-2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2017-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2017-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2017-2021 The OpenSSL Project Authors. All Rights Reserved. Copyright 2018 The OpenSSL Project Authors. All Rights Reserved. Copyright 2018-2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2018-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2019 The OpenSSL Project Authors. All Rights Reserved. Copyright 2019-2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 2020 The OpenSSL Project Authors. All Rights Reserved. Copyright 20xx-20yy The OpenSSL Project Authors. All Rights Reserved. Copyright Patrick Powell 1995 This code is based on code written by Patrick Powell papowell@astart.com Copyright 1998-2021 The OpenSSL Authors. All rights reserved.

License: Apache License 2.0
Please see above.

================================================
FILE: build.sh
================================================

#!/bin/bash

set -e
PROJECTPATH=$(cd "$(dirname $0)"; pwd)
export BUILD_PATH="${PROJECTPATH}/build/"

# print usage message
usage()
{
  echo "Usage:"
  echo "    bash build.sh [-j[n]] [-d] [-S on|off]"
  echo "    bash build.sh -p {mindspore_shared_lib} [-j[n]] [-d] [-S on|off]"
  echo "    bash build.sh -e gpu|ascend [-V 9.2|10.1|310|910] [-j[n]] [-d] [-S on|off]"
  echo "    bash build.sh -t on [-j[n]] [-d] [-S on|off]"
  echo ""
  echo "Options:"
  echo "    -p {mindspore_shared_lib}, Use header files related to MindSpore(libmindspore.so) or Lite lib(libmindspore-lite.so)"
  echo "    -e gpu|ascend, build MindSpore gpu or ascend whl package meanwhile"
  echo "    -V Specify the device version, if -e gpu, default CUDA 10.1, if -e ascend, default Ascend 910"
  echo "    -j[n] Set the threads when building (Default: -j8)"
  echo "    -d Debug mode"
  echo "    -t Build testcases."
echo " -S Enable enable download cmake compile dependency from gitee instead of github, default off" } # check value of input is 'on' or 'off' # usage: check_on_off arg_value arg_name check_on_off() { if [[ "X$1" != "Xon" && "X$1" != "Xoff" ]]; then echo "Invalid value $1 for option -$2" usage exit 1 fi } # check and set options checkopts() { # Init default values of build options THREAD_NUM=8 VERBOSE="" DEBUG_MODE="off" ENABLE_COVERAGE="off" ENABLE_ASAN="off" ENABLE_PYTHON="on" MS_WHL_LIB_PATH="" MS_BACKEND="" MS_BACKEND_HEADER="on" MS_VERSION="" RUN_TESTCASES="off" ENABLE_GITEE="off" # Process the options while getopts 'dvc:j:a:p:e:V:t:S:' opt do LOW_OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in e) echo "user opt: -e"${LOW_OPTARG} if [[ "$OPTARG" != "" ]]; then MS_BACKEND=$OPTARG fi ;; V) echo "user opt: -V"${LOW_OPTARG} if [[ "$OPTARG" != "" ]]; then MS_VERSION=$OPTARG fi ;; p) if [[ "$OPTARG" != "" ]]; then MS_WHL_LIB_PATH=$OPTARG MS_BACKEND_HEADER="off" else echo "Invalid value ${LOW_OPTARG} for option -p" usage exit 1 fi ;; d) echo "user opt: -d"${LOW_OPTARG} DEBUG_MODE="on" ;; j) echo "user opt: -j"${LOW_OPTARG} THREAD_NUM=$OPTARG ;; v) echo "user opt: -v"${LOW_OPTARG} VERBOSE="VERBOSE=1" ;; c) check_on_off $OPTARG c ENABLE_COVERAGE="$OPTARG" ;; a) check_on_off $OPTARG a ENABLE_ASAN="$OPTARG" ;; t) echo "user opt: -t"${LOW_OPTARG} RUN_TESTCASES="$OPTARG" MS_BACKEND_HEADER="off" ;; S) check_on_off $OPTARG S ENABLE_GITEE="$OPTARG" echo "enable download from gitee" ;; *) echo "Unknown option ${opt}!" usage exit 1 esac done } checkopts "$@" echo "---------------- MindSpore Serving: build start ----------------" mkdir -pv "${BUILD_PATH}/package/mindspore_serving/lib" if [[ "$MS_BACKEND_HEADER" != "off" ]]; then git submodule update --init third_party/mindspore fi # Create building path build_mindspore_serving() { echo "start build mindspore_serving project." mkdir -pv "${BUILD_PATH}/mindspore_serving" cd "${BUILD_PATH}/mindspore_serving" CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH" CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_PYTHON=${ENABLE_PYTHON}" CMAKE_ARGS="${CMAKE_ARGS} -DTHREAD_NUM=${THREAD_NUM}" if [[ "X$ENABLE_COVERAGE" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_COVERAGE=ON" fi if [[ "X$ENABLE_ASAN" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ASAN=ON" fi if [[ "$MS_BACKEND" != "" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DMS_BACKEND=${MS_BACKEND}" fi if [[ "$MS_WHL_LIB_PATH" != "" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DMS_WHL_LIB_PATH=${MS_WHL_LIB_PATH}" fi if [[ "$MS_BACKEND_HEADER" != "off" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DMS_BACKEND_HEADER=${MS_BACKEND_HEADER}" fi if [[ "$MS_VERSION" != "" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DMS_VERSION=${MS_VERSION}" fi if [[ "X$RUN_TESTCASES" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_TESTCASES=ON" fi if [[ "X$ENABLE_GITEE" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GITEE=ON" fi echo "${CMAKE_ARGS}" cmake ${CMAKE_ARGS} ../.. if [[ -n "$VERBOSE" ]]; then CMAKE_VERBOSE="--verbose" fi cmake --build . --target package ${CMAKE_VERBOSE} -j$THREAD_NUM echo "success building mindspore_serving project!" 
================================================
FILE: cmake/check_requirements.cmake
================================================

## define customized find functions, print customized error messages
function(find_required_package pkg_name)
    find_package(${pkg_name})
    if(NOT ${pkg_name}_FOUND)
        message(FATAL_ERROR "Required package ${pkg_name} not found, please install the package and try"
                " building mindspore_serving again.")
    endif()
endfunction()

## find python, quit if the found python is static
set(Python3_USE_STATIC_LIBS FALSE)
set(Python3_FIND_VIRTUALENV ONLY)
find_package(Python3 COMPONENTS Interpreter Development)
if(Python3_FOUND)
    message("Python3 found, version: ${Python3_VERSION}")
    message("Python3 library path: ${Python3_LIBRARY}")
    message("Python3 interpreter: ${Python3_EXECUTABLE}")
elseif(Python3_LIBRARY AND Python3_EXECUTABLE AND
        ${Python3_VERSION} VERSION_GREATER_EQUAL "3.7.0" AND ${Python3_VERSION} VERSION_LESS "3.10.0")
    message(WARNING "Maybe python3 environment is broken.")
    message("Python3 library path: ${Python3_LIBRARY}")
    message("Python3 interpreter: ${Python3_EXECUTABLE}")
else()
    message(FATAL_ERROR "Python3 not found, please install Python>=3.7.5, and set --enable-shared "
            "if you are building Python locally")
endif()

## packages used both on windows and linux
if(DEFINED ENV{MS_PATCH_PATH})
    find_program(Patch_EXECUTABLE patch PATHS $ENV{MS_PATCH_PATH})
    set(Patch_FOUND ${Patch_EXECUTABLE})
else()
    find_package(Patch)
endif()
if(NOT Patch_FOUND)
    message(FATAL_ERROR "Patch not found, please set environment variable MS_PATCH_PATH to path where Patch is located,"
            " usually found in GIT_PATH/usr/bin on Windows")
endif()
message(PATCH_EXECUTABLE = ${Patch_EXECUTABLE})

find_required_package(Threads)
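The Patch lookup above honours the MS_PATCH_PATH environment variable named in its error message; a minimal sketch of pointing it at a patch binary before building (the directory is illustrative):

# hypothetical example: directory that contains the 'patch' executable
export MS_PATCH_PATH=/usr/bin
bash build.sh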
================================================
FILE: cmake/dependency_ms.cmake
================================================

# Compile MindSpore
message(STATUS "**********begin to compile MindSpore**********")
set(MS_SOURCE_DIR ${CMAKE_SOURCE_DIR}/third_party/mindspore)
message(STATUS "MindSpore dir: ${MS_SOURCE_DIR}")
message(STATUS "MindSpore compile method: -e${MS_BACKEND}")
message(STATUS "MindSpore compile thread num: -j${THREAD_NUM}")
message(STATUS "MindSpore version: -V${MS_VERSION}")
if(MS_VERSION)
    set(MS_VERSION_OPTION -V${MS_VERSION})
endif()
set(EXEC_COMMAND bash ${MS_SOURCE_DIR}/build.sh -e${MS_BACKEND} ${MS_VERSION_OPTION} -j${THREAD_NUM})
execute_process(
    COMMAND ${EXEC_COMMAND}
    WORKING_DIRECTORY ${MS_SOURCE_DIR}
    RESULT_VARIABLE RESULT
)
if(NOT RESULT EQUAL "0")
    message(FATAL_ERROR "error! when ${EXEC_COMMAND} in ${MS_SOURCE_DIR}")
endif()
message(STATUS "**********end to compile MindSpore**********")

================================================
FILE: cmake/dependency_securec.cmake
================================================

# securec library
#
#
# SECUREC_LIBRARY
#
if(NOT TARGET securec)
    set(_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE ${CMAKE_POSITION_INDEPENDENT_CODE})
    set(_ms_tmp_CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
    set(CMAKE_C_FLAGS "${SECURE_CXX_FLAGS}")
    if(CMAKE_SYSTEM_NAME MATCHES "Windows")
        add_compile_definitions(SECUREC_ONLY_DECLARE_MEMSET)
    endif()
    add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/securec ${CMAKE_BINARY_DIR}/securec)
    set(CMAKE_POSITION_INDEPENDENT_CODE ${_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE})
    set(CMAKE_C_FLAGS ${_ms_tmp_CMAKE_C_FLAGS})
endif()
include_directories(${CMAKE_CURRENT_LIST_DIR}/../third_party/securec/include)
set(SECUREC_LIBRARY securec)

================================================
FILE: cmake/dependency_utils.cmake
================================================

# MS Utils
#
function(find_python_package out_inc out_lib)
    # Use PYTHON_EXECUTABLE if it is defined, otherwise default to python
    if("${PYTHON_EXECUTABLE}" STREQUAL "")
        set(PYTHON_EXECUTABLE "python3")
    else()
        set(PYTHON_EXECUTABLE "${PYTHON_EXECUTABLE}")
    endif()

    execute_process(
            COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())"
            RESULT_VARIABLE result
            OUTPUT_VARIABLE inc)
    string(STRIP "${inc}" inc)
    set(${out_inc} ${inc} PARENT_SCOPE)

    execute_process(
            COMMAND "${PYTHON_EXECUTABLE}" -c "import distutils.sysconfig as sysconfig; import os; \
print(os.path.join(sysconfig.get_config_var('LIBDIR'), sysconfig.get_config_var('LDLIBRARY')))"
            RESULT_VARIABLE result
            OUTPUT_VARIABLE lib)
    string(STRIP "${lib}" lib)
    set(${out_lib} ${lib} PARENT_SCOPE)
endfunction()
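A minimal sketch of how find_python_package might be called from another CMake file; the output variable names and messages are illustrative, not taken from this repository:

find_python_package(py_inc py_lib)    # hypothetical call site
message("Python include dir: ${py_inc}")
message("Python library: ${py_lib}")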
================================================
FILE: cmake/external_libs/absl.cmake
================================================

if(ENABLE_GITEE_EULER)
    set(GIT_REPOSITORY "https://gitee.com/src-openeuler/abseil-cpp.git")
    set(GIT_TAG "openEuler-22.03-LTS")
    set(SHA256 "365b1ecbbcd81b4c58101808a8a28a3cf9ad7f9d05c08080a35c0d4283a44afa")
    set(ABSL_SRC "${CMAKE_SOURCE_DIR}/build/mindspore_serving/_deps/absl-src")
    __download_pkg_with_git(absl ${GIT_REPOSITORY} ${GIT_TAG} ${SHA256})
    execute_process(COMMAND tar -xf ${ABSL_SRC}/abseil-cpp-20210324.2.tar.gz --strip-components 1 -C ${ABSL_SRC})
else()
    if(ENABLE_GITEE)
        set(REQ_URL "https://gitee.com/mirrors/abseil-cpp/repository/archive/20210324.2.tar.gz")
        set(SHA256 "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f")
    else()
        set(REQ_URL "https://github.com/abseil/abseil-cpp/archive/20210324.2.tar.gz")
        set(SHA256 "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f")
    endif()
endif()

if(NOT ENABLE_GLIBCXX)
    set(absl_CXXFLAGS "${absl_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
endif()

mindspore_add_pkg(absl
    VER 20210324.2
    LIBS absl_strings absl_throw_delegate absl_raw_logging_internal absl_int128 absl_bad_optional_access
    URL ${REQ_URL}
    SHA256 ${SHA256}
    CMAKE_OPTION
    -DCMAKE_BUILD_TYPE:STRING=Release
    -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=TRUE
    -DCMAKE_CXX_STANDARD=11
)

include_directories(${absl_INC})

add_library(mindspore_serving::absl_strings ALIAS absl::absl_strings)
add_library(mindspore_serving::absl_throw_delegate ALIAS absl::absl_throw_delegate)
add_library(mindspore_serving::absl_raw_logging_internal ALIAS absl::absl_raw_logging_internal)
add_library(mindspore_serving::absl_int128 ALIAS absl::absl_int128)
add_library(mindspore_serving::absl_bad_optional_access ALIAS absl::absl_bad_optional_access)

================================================
FILE: cmake/external_libs/c-ares.cmake
================================================

if(ENABLE_GITEE)
    set(REQ_URL "https://gitee.com/mirrors/c-ares/repository/archive/cares-1_15_0.tar.gz")
    set(SHA256 "7deb7872cbd876c29036d5f37e30c4cbc3cc068d59d8b749ef85bb0736649f04")
else()
    set(REQ_URL "https://github.com/c-ares/c-ares/releases/download/cares-1_15_0/c-ares-1.15.0.tar.gz")
    set(SHA256 "6cdb97871f2930530c97deb7cf5c8fa4be5a0b02c7cea6e7c7667672a39d6852")
endif()

mindspore_add_pkg(c-ares
    VER 1.15.0
    LIBS cares
    URL ${REQ_URL}
    SHA256 ${SHA256}
    CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release
    -DCARES_SHARED:BOOL=OFF
    -DCARES_STATIC:BOOL=ON
    -DCARES_STATIC_PIC:BOOL=ON
    -DHAVE_LIBNSL:BOOL=OFF
    PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/c-ares/CVE-2021-3672.patch)

include_directories(${c-ares_INC})
add_library(mindspore_serving::cares ALIAS c-ares::cares)

================================================
FILE: cmake/external_libs/eigen.cmake
================================================

set(Eigen3_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(Eigen3_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(REQ_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz")
set(SHA256 "8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72")

mindspore_add_pkg(Eigen3
    VER 3.4.0
    URL ${REQ_URL}
    SHA256 ${SHA256}
    CMAKE_OPTION -DBUILD_TESTING=OFF)

find_package(Eigen3 3.4.0 REQUIRED ${MS_FIND_NO_DEFAULT_PATH})
include_directories(${Eigen3_INC})
include_directories(${EIGEN3_INCLUDE_DIR})
set_property(TARGET Eigen3::Eigen PROPERTY IMPORTED_GLOBAL TRUE)
add_library(mindspore_serving::eigen ALIAS Eigen3::Eigen)
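Each of these external_libs scripts ends by exposing a namespaced mindspore_serving:: ALIAS target, so consumers never touch the underlying imported name. A sketch of a consuming target (the target name is hypothetical):

# hypothetical consumer target elsewhere in the build
target_link_libraries(serving_common PRIVATE mindspore_serving::eigen)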
set(REQ_URL "https://gitee.com/mirrors/glog/repository/archive/v0.4.0.tar.gz") set(SHA256 "e17cd4bb7c06951a12fc9db5130ec63a9f090b84340b8556fa0d530f73c6b634") else() set(REQ_URL "https://github.com/google/glog/archive/v0.4.0.tar.gz") set(SHA256 "f28359aeba12f30d73d9e4711ef356dc842886968112162bc73002645139c39c") endif() set(glog_option -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DWITH_GFLAGS=OFF -DCMAKE_BUILD_TYPE=Release) if(WIN32 AND NOT MSVC) if(CMAKE_SIZEOF_VOID_P EQUAL 4) set(glog_option ${glog_option} -DHAVE_DBGHELP=ON) endif() endif() mindspore_add_pkg(glog VER 0.4.0 LIBS ${glog_lib} URL ${REQ_URL} SHA256 ${SHA256} PATCHES ${glog_patch} CMAKE_OPTION ${glog_option}) include_directories(${glog_INC}) add_library(mindspore_serving::glog ALIAS glog::${glog_lib}) ================================================ FILE: cmake/external_libs/grpc.cmake ================================================ set(grpc_USE_STATIC_LIBS OFF) if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") set(grpc_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC -D_FORTIFY_SOURCE=2 -O2 \ -Dgrpc=mindspore_serving_grpc -Dgrpc_impl=mindspore_serving_grpc_impl -Dgrpc_core=mindspore_serving_grpc_core") elseif(${CMAKE_SYSTEM_NAME} MATCHES "Windows") set(grpc_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2") else() set(grpc_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2 \ -Dgrpc=mindspore_serving_grpc -Dgrpc_impl=mindspore_serving_grpc_impl -Dgrpc_core=mindspore_serving_grpc_core") set(grpc_CFLAGS "-fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") if(NOT ENABLE_GLIBCXX) set(grpc_CXXFLAGS "${grpc_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") endif() endif() if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") set(grpc_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") endif() if(EXISTS ${protobuf_ROOT}/lib64) set(_FINDPACKAGE_PROTOBUF_CONFIG_DIR "${protobuf_ROOT}/lib64/cmake/protobuf") else() set(_FINDPACKAGE_PROTOBUF_CONFIG_DIR "${protobuf_ROOT}/lib/cmake/protobuf") endif() message("grpc using Protobuf_DIR : " ${_FINDPACKAGE_PROTOBUF_CONFIG_DIR}) if(EXISTS ${absl_ROOT}/lib64) set(_FINDPACKAGE_ABSL_CONFIG_DIR "${absl_ROOT}/lib64/cmake/absl") else() set(_FINDPACKAGE_ABSL_CONFIG_DIR "${absl_ROOT}/lib/cmake/absl") endif() message("grpc using absl_DIR : " ${_FINDPACKAGE_ABSL_CONFIG_DIR}) if(EXISTS ${re2_ROOT}/lib64) set(_FINDPACKAGE_RE2_CONFIG_DIR "${re2_ROOT}/lib64/cmake/re2") else() set(_FINDPACKAGE_RE2_CONFIG_DIR "${re2_ROOT}/lib/cmake/re2") endif() message("grpc using re2_DIR : " ${_FINDPACKAGE_RE2_CONFIG_DIR}) if(EXISTS ${openssl_ROOT}) set(_CMAKE_ARGS_OPENSSL_ROOT_DIR "-DOPENSSL_ROOT_DIR:PATH=${openssl_ROOT}") endif() if(ENABLE_GITEE) set(REQ_URL "https://gitee.com/mirrors/grpc/repository/archive/v1.36.1.tar.gz") set(SHA256 "17a3ac19345a6aeda01b2baba5400e1136b02b44770dbdfe8581255a091aaf87") else() set(REQ_URL "https://github.com/grpc/grpc/archive/v1.36.1.tar.gz") set(SHA256 "adf51558bf3d057a65651880c9814e09e77b61573eb950c2be1142a624d58e69") endif() mindspore_add_pkg(grpc VER 1.36.1 LIBS mindspore_serving_grpc++ mindspore_serving_grpc mindspore_serving_gpr mindspore_serving_upb mindspore_serving_address_sorting EXE grpc_cpp_plugin grpc_python_plugin URL ${REQ_URL} SHA256 ${SHA256} PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/grpc/grpc.patch001 CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release -DBUILD_SHARED_LIBS=ON -DgRPC_INSTALL:BOOL=ON -DgRPC_BUILD_TESTS:BOOL=OFF 
    -DgRPC_PROTOBUF_PROVIDER:STRING=package
    -DgRPC_PROTOBUF_PACKAGE_TYPE:STRING=CONFIG
    -DProtobuf_DIR:PATH=${_FINDPACKAGE_PROTOBUF_CONFIG_DIR}
    -DgRPC_ZLIB_PROVIDER:STRING=package
    -DZLIB_ROOT:PATH=${zlib_ROOT}
    -DgRPC_ABSL_PROVIDER:STRING=package
    -Dabsl_DIR:PATH=${_FINDPACKAGE_ABSL_CONFIG_DIR}
    -DgRPC_CARES_PROVIDER:STRING=package
    -Dc-ares_DIR:PATH=${c-ares_ROOT}/lib/cmake/c-ares
    -DgRPC_SSL_PROVIDER:STRING=package
    ${_CMAKE_ARGS_OPENSSL_ROOT_DIR}
    -DgRPC_RE2_PROVIDER:STRING=package
    -Dre2_DIR:PATH=${_FINDPACKAGE_RE2_CONFIG_DIR}
    )

include_directories(${grpc_INC})
add_library(mindspore_serving::grpc++ ALIAS grpc::mindspore_serving_grpc++)

# link other grpc libs
target_link_libraries(grpc::mindspore_serving_grpc++ INTERFACE grpc::mindspore_serving_grpc
        grpc::mindspore_serving_gpr grpc::mindspore_serving_upb grpc::mindspore_serving_address_sorting)

# modify mindspore macro define
add_compile_definitions(grpc=mindspore_serving_grpc)
add_compile_definitions(grpc_impl=mindspore_serving_grpc_impl)
add_compile_definitions(grpc_core=mindspore_serving_grpc_core)

function(ms_grpc_generate c_var h_var)
    if(NOT ARGN)
        message(SEND_ERROR "Error: ms_grpc_generate() called without any proto files")
        return()
    endif()

    set(${c_var})
    set(${h_var})

    foreach(proto_file_with_path ${ARGN})
        message(proto_file_with_path: ${proto_file_with_path})
        get_filename_component(proto_file_absolute "${proto_file_with_path}" ABSOLUTE)
        message(proto_file_absolute: ${proto_file_absolute})
        get_filename_component(file_dir ${proto_file_absolute} DIRECTORY)
        get_filename_component(proto_I_DIR "${file_dir}/../../" ABSOLUTE)
        get_filename_component(proto_file ${proto_file_absolute} NAME)
        get_filename_component(proto_file_prefix ${proto_file_absolute} NAME_WE)

        set(proto_file_relative "mindspore_serving/proto/${proto_file}")
        set(protoc_output_prefix ${CMAKE_BINARY_DIR}/mindspore_serving/proto)

        set(hw_proto_srcs "${protoc_output_prefix}/${proto_file_prefix}.pb.cc")
        set(hw_proto_hdrs "${protoc_output_prefix}/${proto_file_prefix}.pb.h")
        set(hw_grpc_srcs "${protoc_output_prefix}/${proto_file_prefix}.grpc.pb.cc")
        set(hw_grpc_hdrs "${protoc_output_prefix}/${proto_file_prefix}.grpc.pb.h")
        set(hw_py_pb2 "${protoc_output_prefix}/${proto_file_prefix}_pb2.py")
        set(hw_py_pb2_grpc "${protoc_output_prefix}/${proto_file_prefix}_pb2_grpc.py")

        # The $<TARGET_FILE:...> generator expressions below resolve to the protoc binary and
        # the gRPC C++/Python plugins built by the protobuf/grpc packages declared in this tree.
        add_custom_command(
            OUTPUT ${hw_proto_srcs} ${hw_proto_hdrs} ${hw_grpc_srcs} ${hw_grpc_hdrs} ${hw_py_pb2} ${hw_py_pb2_grpc}
            WORKING_DIRECTORY ${proto_I_DIR}
            COMMAND $<TARGET_FILE:protobuf::protoc> ARGS --grpc_out "${CMAKE_BINARY_DIR}" --cpp_out "${CMAKE_BINARY_DIR}"
                -I "${proto_I_DIR}" --plugin=protoc-gen-grpc=$<TARGET_FILE:grpc::grpc_cpp_plugin> "${proto_file_relative}"
            COMMAND $<TARGET_FILE:protobuf::protoc> ARGS --grpc_out "${CMAKE_BINARY_DIR}" --python_out "${CMAKE_BINARY_DIR}"
                -I "${proto_I_DIR}" --plugin=protoc-gen-grpc=$<TARGET_FILE:grpc::grpc_python_plugin> "${proto_file_relative}"
            DEPENDS "${proto_file_absolute}")

        list(APPEND ${c_var} ${hw_proto_srcs} ${hw_grpc_srcs})
        list(APPEND ${h_var} ${hw_proto_hdrs} ${hw_grpc_hdrs})
    endforeach()

    set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE)
    set(${c_var} ${${c_var}} PARENT_SCOPE)
    set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction()
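A sketch of how ms_grpc_generate might be invoked from a CMakeLists.txt in this tree; the proto file name and library target are hypothetical, though the function itself assumes the mindspore_serving/proto/ layout seen above:

ms_grpc_generate(SERVING_PROTO_SRCS SERVING_PROTO_HDRS
        ${CMAKE_SOURCE_DIR}/mindspore_serving/proto/ms_service.proto)    # hypothetical proto file
add_library(serving_proto_obj OBJECT ${SERVING_PROTO_SRCS})             # hypothetical consumer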
-D_GLIBCXX_USE_CXX11_ABI=0") endif() if(ENABLE_GITEE) set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.1.tar.gz") set(SHA256 "9bf1fe5182a604b4135edc1a425ae356c9ad15e9b23f9f12a02e80184c3a249c") else() set(REQ_URL "https://github.com/google/googletest/archive/release-1.8.1.tar.gz") set(SHA256 "9bf1fe5182a604b4135edc1a425ae356c9ad15e9b23f9f12a02e80184c3a249c") endif() mindspore_add_pkg(gtest VER 1.8.1 LIBS gtest gmock URL ${REQ_URL} SHA256 ${SHA256} CMAKE_OPTION ${CMAKE_OPTION}) include_directories(${gtest_INC}) add_library(mindspore_serving::gtest ALIAS gtest::gtest) add_library(mindspore_serving::gmock ALIAS gtest::gmock) if(CMAKE_SYSTEM_NAME MATCHES "Windows") file(COPY ${gtest_DIRPATH}/bin/libgtest${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) file(COPY ${gtest_DIRPATH}/bin/libgtest_main${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) file(COPY ${gtest_DIRPATH}/bin/libgmock_main${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) file(COPY ${gtest_DIRPATH}/bin/libgmock${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) else() file(COPY ${gtest_LIBPATH}/libgtest${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) file(COPY ${gtest_LIBPATH}/libgtest_main${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) file(COPY ${gtest_LIBPATH}/libgmock${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) file(COPY ${gtest_LIBPATH}/libgmock_main${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) endif() ================================================ FILE: cmake/external_libs/json.cmake ================================================ if(MSVC) set(flatbuffers_CXXFLAGS "${CMAKE_CXX_FLAGS}") set(flatbuffers_CFLAGS "${CMAKE_CXX_FLAGS}") set(flatbuffers_LDFLAGS "${CMAKE_SHARED_LINKER_FLAGS}") else() set(nlohmann_json3101_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") set(nlohmann_json3101_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") endif() if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.10.1.zip") set(SHA256 "5c7d0a0542431fef628f8dc4c34fd022fe8747ccb577012d58f38672d8747e0d") set(INCLUDE "./include") else() set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.10.1/include.zip") set(SHA256 "144268f7f85afb0f0fbea7c796723c849724c975f9108ffdadde9ecedaa5f0b1") set(INCLUDE "./include") endif() mindspore_add_pkg(nlohmann_json3101 VER 3.10.1 HEAD_ONLY ${INCLUDE} URL ${REQ_URL} SHA256 ${SHA256}) include_directories(${nlohmann_json3101_INC}) add_library(mindspore_serving::json ALIAS nlohmann_json3101) ================================================ FILE: cmake/external_libs/libevent.cmake ================================================ set(openssl_USE_STATIC_LIBS ON) set(libevent_CFLAGS "-fPIC -fvisibility=hidden -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") if(NOT CMAKE_SYSTEM_NAME MATCHES "Darwin") set(libevent_LDFLAGS "-Wl,-z,now") endif() if(NOT MINDSPORE_PROJECT_DIR) set(MINDSPORE_PROJECT_DIR ${CMAKE_SOURCE_DIR}) endif() if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. 
set(REQ_URL "https://gitee.com/mirrors/libevent/repository/archive/release-2.1.12-stable.tar.gz") set(SHA256 "7180a979aaa7000e1264da484f712d403fcf7679b1e9212c4e3d09f5c93efc24") else() set(REQ_URL "https://github.com/libevent/libevent/releases/download/release-2.1.12-stable/libevent-2.1.12-stable.tar.gz") set(SHA256 "92e6de1be9ec176428fd2367677e61ceffc2ee1cb119035037a27d346b0403bb") endif() message("libevent using openssl stub dir: " ${openssl_ROOT}) mindspore_add_pkg(libevent VER 2.1.12 LIBS event event_pthreads event_core event_openssl URL ${REQ_URL} SHA256 ${SHA256} PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/libevent/libevent.patch001 CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release -DBUILD_TESTING=OFF -DOPENSSL_ROOT_DIR:PATH=${openssl_ROOT} -DEVENT__LIBRARY_TYPE:STRING=STATIC) include_directories(${libevent_INC}) add_library(mindspore_serving::event ALIAS libevent::event) add_library(mindspore_serving::event_pthreads ALIAS libevent::event_pthreads) add_library(mindspore_serving::event_core ALIAS libevent::event_core) add_library(mindspore_serving::event_openssl ALIAS libevent::event_openssl) ================================================ FILE: cmake/external_libs/openssl.cmake ================================================ if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. set(REQ_URL "https://gitee.com/mirrors/openssl/repository/archive/OpenSSL_1_1_1k.tar.gz") set(SHA256 "b92f9d3d12043c02860e5e602e50a73ed21a69947bcc74d391f41148e9f6aa95") else() set(REQ_URL "https://github.com/openssl/openssl/archive/refs/tags/OpenSSL_1_1_1k.tar.gz") set(SHA256 "b92f9d3d12043c02860e5e602e50a73ed21a69947bcc74d391f41148e9f6aa95") endif() set(OPENSSL_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/openssl) if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE) set(openssl_CFLAGS -fvisibility=hidden) mindspore_add_pkg(openssl VER 1.1.1k LIBS ssl crypto URL ${REQ_URL} SHA256 ${SHA256} CONFIGURE_COMMAND ./config no-zlib no-shared PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-4160.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2022-0778.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2022-1292.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2022-2068.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2022-2097.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2022-4304.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2022-4450.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2023-0215.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2023-0286.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2023-0464.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2023-0465.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2023-0466.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2023-2650.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2023-3446.patch PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2023-4807.patch ) include_directories(${openssl_INC}) add_library(mindspore_serving::ssl ALIAS openssl::ssl) add_library(mindspore_serving::crypto ALIAS openssl::crypto) endif() ================================================ FILE: cmake/external_libs/protobuf.cmake ================================================ set(protobuf_USE_STATIC_LIBS ON) if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC \ -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") elseif(${CMAKE_SYSTEM_NAME} MATCHES "Windows") set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter \ -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 
-O2") else() set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter \ -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") if(NOT ENABLE_GLIBCXX) set(protobuf_CXXFLAGS "${protobuf_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") endif() endif() set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS}) string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") set(SHA256 "ab9b39e7053a6fb06b01bf75fb6ec6a71a1ada5a5f8e2446f927336e97b9e7bb") else() set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz") set(SHA256 "9b4ee22c250fe31b16f1a24d61467e40780a3fbb9b91c3b65be2a376ed913a1a") endif() set(PROTOBUF_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/protobuf) mindspore_add_pkg(protobuf VER 3.13.0 LIBS protobuf EXE protoc URL ${REQ_URL} SHA256 ${SHA256} CMAKE_PATH cmake/ CMAKE_OPTION -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release PATCHES ${PROTOBUF_PATCH_ROOT}/CVE-2021-22570.patch PATCHES ${PROTOBUF_PATCH_ROOT}/CVE-2022-1941.patch) include_directories(${protobuf_INC}) include_directories(${CMAKE_BINARY_DIR}/proto_py) add_library(mindspore_serving::protobuf ALIAS protobuf::protobuf) set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS}) # recover original value if(MSVC) set(CMAKE_STATIC_LIBRARY_PREFIX, ${_ms_tmp_CMAKE_STATIC_LIBRARY_PREFIX}) endif() ================================================ FILE: cmake/external_libs/pybind11.cmake ================================================ set(PYTHON_VERSION ${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}) if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. 
if(PYTHON_VERSION MATCHES "3.9") set(REQ_URL "https://gitee.com/mirrors/pybind11/repository/archive/v2.6.1.tar.gz") set(SHA256 "c840509be94ac97216c3b4a3ed9f3fdba9948dbe38c16fcfaee3acc6dc93ed0e") elseif(PYTHON_VERSION MATCHES "3.8") set(REQ_URL "https://gitee.com/mirrors/pybind11/repository/archive/v2.6.1.tar.gz") set(SHA256 "c840509be94ac97216c3b4a3ed9f3fdba9948dbe38c16fcfaee3acc6dc93ed0e") elseif(PYTHON_VERSION MATCHES "3.7") set(REQ_URL "https://gitee.com/mirrors/pybind11/repository/archive/v2.4.3.tar.gz") set(SHA256 "182cf9e2c5a7ae6f03f84cf17e826d7aa2b02aa2f3705db684dfe686c0278b36") else() message("Could not find 'Python 3.8' or 'Python 3.7' or 'Python 3.9'") return() endif() else() if(PYTHON_VERSION MATCHES "3.9") set(REQ_URL "https://github.com/pybind/pybind11/archive/v2.6.1.tar.gz") set(SHA256 "cdbe326d357f18b83d10322ba202d69f11b2f49e2d87ade0dc2be0c5c34f8e2a") elseif(PYTHON_VERSION MATCHES "3.8") set(REQ_URL "https://github.com/pybind/pybind11/archive/v2.6.1.tar.gz") set(SHA256 "cdbe326d357f18b83d10322ba202d69f11b2f49e2d87ade0dc2be0c5c34f8e2a") elseif(PYTHON_VERSION MATCHES "3.7") set(REQ_URL "https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz") set(SHA256 "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d") else() message("Could not find 'Python 3.8' or 'Python 3.7' or 'Python 3.9'") return() endif() endif() set(pybind11_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") set(pybind11_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") set(pybind11_patch ${CMAKE_SOURCE_DIR}/third_party/patch/pybind11/pybind11.patch001) if(PYTHON_VERSION MATCHES "3.9") mindspore_add_pkg(pybind11 VER 2.6.1 URL ${REQ_URL} SHA256 ${SHA256} PATCHES ${pybind11_patch} CMAKE_OPTION -DPYBIND11_TEST=OFF -DPYBIND11_LTO_CXX_FLAGS=FALSE ) elseif(PYTHON_VERSION MATCHES "3.8") mindspore_add_pkg(pybind11 VER 2.6.1 URL ${REQ_URL} SHA256 ${SHA256} CMAKE_OPTION -DPYBIND11_TEST=OFF -DPYBIND11_LTO_CXX_FLAGS=FALSE ) else() mindspore_add_pkg(pybind11 VER 2.4.3 URL ${REQ_URL} SHA256 ${SHA256} CMAKE_OPTION -DPYBIND11_TEST=OFF -DPYBIND11_LTO_CXX_FLAGS=FALSE ) endif() include_directories(${pybind11_INC}) find_package(pybind11 REQUIRED) set_property(TARGET pybind11::module PROPERTY IMPORTED_GLOBAL TRUE) add_library(mindspore_serving::pybind11_module ALIAS pybind11::module) ================================================ FILE: cmake/external_libs/re2.cmake ================================================ if(ENABLE_GITEE) set(REQ_URL "https://gitee.com/mirrors/re2/repository/archive/2019-12-01.tar.gz") set(SHA256 "7268e1b4254d9ffa5ccf010fee954150dbb788fd9705234442e7d9f0ee5a42d3") else() set(REQ_URL "https://github.com/google/re2/archive/2019-12-01.tar.gz") set(SHA256 "7268e1b4254d9ffa5ccf010fee954150dbb788fd9705234442e7d9f0ee5a42d3") endif() if(NOT ENABLE_GLIBCXX) set(re2_CXXFLAGS "${re2_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") endif() mindspore_add_pkg(re2 VER 20191201 LIBS re2 URL ${REQ_URL} SHA256 ${SHA256} CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=TRUE) include_directories(${re2_INC}) add_library(mindspore_serving::re2 ALIAS re2::re2) ================================================ FILE: cmake/external_libs/zlib.cmake ================================================ if(ENABLE_GITEE) set(REQ_URL "https://gitee.com/mirrors/zlib/repository/archive/v1.2.11.tar.gz") set(SHA256 "f21b3885cc7732f0ab93dbe06ff1ec58069bb58657b3fda89531d1562d8ad708") else() set(REQ_URL "https://github.com/madler/zlib/archive/v1.2.11.tar.gz") set(SHA256 "629380c90a77b964d896ed37163f5c3a34f6e6d897311f1df2a7016355c45eff") 
endif()
mindspore_add_pkg(zlib
        VER 1.2.11
        LIBS z
        URL ${REQ_URL}
        SHA256 ${SHA256}
        CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release
        PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/zlib/CVE-2018-25032.patch
        PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/zlib/CVE-2022-37434.patch)
include_directories(${zlib_INC})
add_library(mindspore_serving::z ALIAS zlib::z)
================================================
FILE: cmake/mind_expression.cmake
================================================
set(SECURE_CXX_FLAGS "")
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
    set(SECURE_CXX_FLAGS "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
endif()
set(_ms_tmp_CMAKE_CXX_FLAGS_F ${CMAKE_CXX_FLAGS})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")

# define third party library download function
include(cmake/utils.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
# build dependencies of gRPC
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/zlib.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/openssl.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/re2.cmake)
# build gRPC
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/grpc.cmake)
# build event
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libevent.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pybind11.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake)
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS_F})

if(MS_BACKEND)
    include(${CMAKE_SOURCE_DIR}/cmake/dependency_ms.cmake)
endif()
================================================
FILE: cmake/options.cmake
================================================
option(DEBUG_MODE "Debug mode, default off" OFF)
option(ENABLE_COVERAGE "Enable code coverage report" OFF)
option(ENABLE_PYTHON "Enable python" ON)
option(ENABLE_ASAN "Enable Google Sanitizer to find memory bugs")
option(MS_WHL_LIB_PATH "MindSpore lib path")
option(MS_BACKEND "Compile MindSpore")
option(RUN_TESTCASES "Compile UT")

if(MS_WHL_LIB_PATH)
    message("MindSpore whl lib path:" ${MS_WHL_LIB_PATH})
elseif(MS_BACKEND)
    message("MindSpore backend method:" ${MS_BACKEND})
elseif(MS_BACKEND_HEADER)
    message("MindSpore backend method:" ${MS_BACKEND_HEADER})
elseif(RUN_TESTCASES)
    message("MindSpore Serving Compile UT:" ${RUN_TESTCASES})
else()
    message(FATAL_ERROR "Please confirm how to use MindSpore.")
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_SYSTEM_NAME MATCHES "Linux")
    set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
endif()

if(ENABLE_COVERAGE)
    set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -ftest-coverage")
    set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} ${COVERAGE_COMPILER_FLAGS}")
endif()

if(ENABLE_ASAN)
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
        set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fsanitize=address -fsanitize-recover=address \
            -fno-omit-frame-pointer -fsanitize=undefined")
    else()
        set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer -static-libsan \
            -fsanitize=undefined")
    endif()
endif()

if(DEBUG_MODE)
    set(CMAKE_BUILD_TYPE "Debug")
    add_compile_definitions(MEM_REUSE_DEBUG)
else()
    set(CMAKE_BUILD_TYPE "Release")
endif()

if((CMAKE_SYSTEM_PROCESSOR
MATCHES "aarch64") OR (CMAKE_BUILD_TYPE STREQUAL Release)) set(PYBIND11_LTO_CXX_FLAGS FALSE) endif() if(NOT BUILD_PATH) set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") endif() if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows") set(MS_BUILD_GRPC ON) endif() add_compile_definitions(USE_GLOG) ================================================ FILE: cmake/package.cmake ================================================ # include dependency include(CMakePackageConfigHelpers) include(GNUInstallDirs) # set package information set(CPACK_PACKAGE_NAME ${PROJECT_NAME}) set(CPACK_GENERATOR "External") set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${CMAKE_SOURCE_DIR}/cmake/package_script.cmake) set(CPACK_EXTERNAL_ENABLE_STAGING true) set(CPACK_TEMPORARY_PACKAGE_FILE_NAME ${CMAKE_SOURCE_DIR}/build/package/mindspore_serving) set(CPACK_TEMPORARY_INSTALL_DIRECTORY ${CMAKE_SOURCE_DIR}/build/package/mindspore_serving) set(CPACK_MS_PACKAGE_NAME "mindspore_serving") include(CPack) # set install path set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries") set(INSTALL_PY_DIR ".") set(INSTALL_BASE_DIR ".") set(INSTALL_LIB_DIR "lib") # grpc install(FILES ${grpc_LIBPATH}/libmindspore_serving_grpc++.so.1.36.1 DESTINATION ${INSTALL_LIB_DIR} RENAME libmindspore_serving_grpc++.so.1 COMPONENT mindspore_serving) install(FILES ${grpc_LIBPATH}/libmindspore_serving_grpc.so.15.0.0 DESTINATION ${INSTALL_LIB_DIR} RENAME libmindspore_serving_grpc.so.15 COMPONENT mindspore_serving) install(FILES ${grpc_LIBPATH}/libmindspore_serving_gpr.so.15.0.0 DESTINATION ${INSTALL_LIB_DIR} RENAME libmindspore_serving_gpr.so.15 COMPONENT mindspore_serving) install(FILES ${grpc_LIBPATH}/libmindspore_serving_upb.so.15.0.0 DESTINATION ${INSTALL_LIB_DIR} RENAME libmindspore_serving_upb.so.15 COMPONENT mindspore_serving) install(FILES ${grpc_LIBPATH}/libmindspore_serving_address_sorting.so.15.0.0 DESTINATION ${INSTALL_LIB_DIR} RENAME libmindspore_serving_address_sorting.so.15 COMPONENT mindspore_serving) # glog install(FILES ${glog_LIBPATH}/libmindspore_serving_glog.so.0.4.0 DESTINATION ${INSTALL_LIB_DIR} RENAME libmindspore_serving_glog.so.0 COMPONENT mindspore) # set python files file(GLOB MS_PY_LIST ${CMAKE_SOURCE_DIR}/mindspore_serving/*.py) install( FILES ${MS_PY_LIST} DESTINATION ${INSTALL_PY_DIR} COMPONENT mindspore_serving ) install( TARGETS _mindspore_serving DESTINATION ${INSTALL_BASE_DIR} COMPONENT mindspore_serving ) install( TARGETS serving_common DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore_serving ) install( TARGETS serving_ascend DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore_serving ) install( DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore_serving/server ${CMAKE_SOURCE_DIR}/mindspore_serving/client DESTINATION ${INSTALL_PY_DIR} COMPONENT mindspore_serving ) install( FILES ${CMAKE_BINARY_DIR}/mindspore_serving/proto/ms_service_pb2.py ${CMAKE_BINARY_DIR}/mindspore_serving/proto/ms_service_pb2_grpc.py DESTINATION ${INSTALL_PY_DIR}/proto COMPONENT mindspore_serving ) ================================================ FILE: cmake/package_script.cmake ================================================ # find exec find_package(Python3 3.7 COMPONENTS Interpreter) if(NOT Python3_FOUND) message(FATAL_ERROR "No python3 found.") endif() set(PYTHON ${Python3_EXECUTABLE}) set(PYTHON_VERSION ${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}) if(NOT (PYTHON_VERSION MATCHES "3.7" OR PYTHON_VERSION MATCHES "3.8" OR PYTHON_VERSION MATCHES "3.9")) message(FATAL_ERROR "FIND PYTHON VERSION ${PYTHON_VERSION} BUT CAN NOT MATCH PYTHON VERSION 
3.7, 3.8 OR 3.9") endif() find_package(Git) if(NOT GIT_FOUND) message("No git found.") return() endif() set(GIT ${GIT_EXECUTABLE}) # set path set(MS_ROOT_DIR ${CPACK_PACKAGE_DIRECTORY}/../../) set(MS_PACK_ROOT_DIR ${MS_ROOT_DIR}/build/package) # set package file name if(CMAKE_SYSTEM_NAME MATCHES "Linux") if(PYTHON_VERSION MATCHES "3.7") set(PY_TAGS "cp37-cp37m") elseif(PYTHON_VERSION MATCHES "3.8") set(PY_TAGS "cp38-cp38") elseif(PYTHON_VERSION MATCHES "3.9") set(PY_TAGS "cp39-cp39") else() message("Could not find 'Python 3.7', 'Python 3.8' or 'Python 3.9'") return() endif() string(TOLOWER linux_${CMAKE_HOST_SYSTEM_PROCESSOR} PLATFORM_TAG) elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") if(PYTHON_VERSION MATCHES "3.7") set(PY_TAGS "py37-none") elseif(PYTHON_VERSION MATCHES "3.8") set(PY_TAGS "py38-none") elseif(PYTHON_VERSION MATCHES "3.9") set(PY_TAGS "py39-none") else() message("Could not find 'Python 3.7', 'Python 3.8' or 'Python 3.9'") return() endif() set(PLATFORM_TAG "any") elseif(CMAKE_SYSTEM_NAME MATCHES "Windows") if(PYTHON_VERSION MATCHES "3.7") set(PY_TAGS "cp37-cp37m") elseif(PYTHON_VERSION MATCHES "3.8") set(PY_TAGS "cp38-cp38") elseif(PYTHON_VERSION MATCHES "3.9") set(PY_TAGS "cp39-cp39") else() message("Could not find 'Python 3.7', 'Python 3.8' or 'Python 3.9'") return() endif() set(PLATFORM_TAG "win_amd64") else() message(FATAL_ERROR "other platform: ${CMAKE_SYSTEM_NAME}") endif() # get git commit id set(GIT_COMMIT_ID "") execute_process( COMMAND ${GIT} log --format='[sha1]:%h,[branch]:%d' --abbrev=8 -1 OUTPUT_VARIABLE GIT_COMMIT_ID WORKING_DIRECTORY ${MS_ROOT_DIR} ERROR_QUIET) string(REPLACE " " "" GIT_COMMIT_ID ${GIT_COMMIT_ID}) set(ENV{MS_PACKAGE_NAME} ${CPACK_MS_PACKAGE_NAME}) set(ENV{COMMIT_ID} ${GIT_COMMIT_ID}) execute_process( COMMAND ${PYTHON} ${MS_ROOT_DIR}/setup.py "bdist_wheel" WORKING_DIRECTORY ${MS_PACK_ROOT_DIR} ) # finally set(PACKAGE_NAME ${CPACK_MS_PACKAGE_NAME}) if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows") string(REPLACE "-" "_" PACKAGE_NAME ${PACKAGE_NAME}) execute_process( COMMAND chmod -R 700 ${MS_PACK_ROOT_DIR}/mindspore_serving/ COMMAND chmod -R 700 ${MS_PACK_ROOT_DIR}/${PACKAGE_NAME}.egg-info/ ) endif() file(GLOB WHL_FILE ${MS_PACK_ROOT_DIR}/dist/*.whl) get_filename_component(ORIGIN_FILE_NAME ${WHL_FILE} NAME) string(REPLACE "-" ";" ORIGIN_FILE_NAME ${ORIGIN_FILE_NAME}) list(GET ORIGIN_FILE_NAME 1 VERSION) set(NEW_FILE_NAME ${PACKAGE_NAME}-${VERSION}-${PY_TAGS}-${PLATFORM_TAG}.whl) file(RENAME ${WHL_FILE} ${MS_PACK_ROOT_DIR}/${NEW_FILE_NAME}) file(REMOVE_RECURSE ${MS_ROOT_DIR}/output) file(MAKE_DIRECTORY ${MS_ROOT_DIR}/output) file(COPY ${MS_PACK_ROOT_DIR}/${NEW_FILE_NAME} DESTINATION ${MS_ROOT_DIR}/output/) file(SHA256 ${MS_ROOT_DIR}/output/${NEW_FILE_NAME} SHA256_VAR) file(WRITE ${MS_ROOT_DIR}/output/${NEW_FILE_NAME}.sha256 ${SHA256_VAR} " " ${NEW_FILE_NAME}) ================================================ FILE: cmake/utils.cmake ================================================ include(FetchContent) # 下载第三方库 set(FETCHCONTENT_QUIET OFF) function(mindspore_add_submodule_obj des_submodule_objs sub_dir submodule_name_obj) add_subdirectory(${sub_dir}) if(NOT TARGET ${submodule_name_obj}) message(FATAL_ERROR "Can not find submodule '${submodule_name_obj}'. in ${CMAKE_CURRENT_LIST_FILE}") endif() if("$" IN_LIST ${des_submodule_objs}) message(FATAL_ERROR "submodule '${submodule_name_obj}' added more than once. 
in ${CMAKE_CURRENT_LIST_FILE}") endif() set(${des_submodule_objs} ${${des_submodule_objs}} $ PARENT_SCOPE) endfunction() if(DEFINED ENV{MSLIBS_CACHE_PATH}) set(_MS_LIB_CACHE $ENV{MSLIBS_CACHE_PATH}) else() set(_MS_LIB_CACHE ${CMAKE_BINARY_DIR}/.mslib) endif() message("MS LIBS CACHE PATH: ${_MS_LIB_CACHE}") if(NOT EXISTS ${_MS_LIB_CACHE}) file(MAKE_DIRECTORY ${_MS_LIB_CACHE}) endif() if(DEFINED ENV{MSLIBS_SERVER}) # export MSLIBS_SERVER=49.4.0.74 set(LOCAL_LIBS_SERVER $ENV{MSLIBS_SERVER}) message("LOCAL_LIBS_SERVER: ${LOCAL_LIBS_SERVER}") endif() include(ProcessorCount) # 确定处理器/核的数量并将值保存在${var}中 ProcessorCount(N) if(JOBS) set(THNUM ${JOBS}) else() set(JOBS 8) if(${JOBS} GREATER ${N}) set(THNUM ${N}) else() set(THNUM ${JOBS}) endif() endif() message("set make thread num: ${THNUM}") if(LOCAL_LIBS_SERVER) if(NOT ENV{no_proxy}) set(ENV{no_proxy} "${LOCAL_LIBS_SERVER}") else() string(FIND $ENV{no_proxy} ${LOCAL_LIBS_SERVER} IP_POS) if(${IP_POS} EQUAL -1) set(ENV{no_proxy} "$ENV{no_proxy},${LOCAL_LIBS_SERVER}") endif() endif() endif() function(__download_pkg pkg_name pkg_url pkg_sha256) if(LOCAL_LIBS_SERVER) get_filename_component(_URL_FILE_NAME ${pkg_url} NAME) set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${_URL_FILE_NAME}" ${pkg_url}) endif() FetchContent_Declare( # 获取项目。可以是一个URL也可以是一个Git仓库。 ${pkg_name} URL ${pkg_url} URL_HASH SHA256=${pkg_sha256} ) FetchContent_GetProperties(${pkg_name}) # 获取我们需要的变量MyName_*。 message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") if(NOT ${pkg_name}_POPULATED) FetchContent_Populate(${pkg_name}) # 将信息记录在可以随时查询的全局属性中 set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) endif() endfunction() function(__download_pkg_with_git pkg_name pkg_url pkg_git_commit pkg_sha256) if(LOCAL_LIBS_SERVER) set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${pkg_git_commit}") FetchContent_Declare( ${pkg_name} URL ${pkg_url} URL_HASH SHA256=${pkg_sha256} ) else() FetchContent_Declare( ${pkg_name} GIT_REPOSITORY ${pkg_url} GIT_TAG ${pkg_git_commit}) endif() FetchContent_GetProperties(${pkg_name}) message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") if(NOT ${pkg_name}_POPULATED) FetchContent_Populate(${pkg_name}) set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) endif() endfunction() function(__find_pkg_then_add_target_lib pkg_name lib_path) unset(${pkg_name}_LIBS) message("_FIND:${${pkg_name}_BASE_DIR}") foreach(_LIB_NAME ${ARGN}) set(_LIB_SEARCH_NAME ${_LIB_NAME}) set(_LIB_TYPE SHARED) if(${pkg_name}_USE_STATIC_LIBS) set(_LIB_SEARCH_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}") set(_LIB_TYPE STATIC) endif() set(${_LIB_NAME}_LIB ${_LIB_NAME}_LIB-NOTFOUND) find_library(${_LIB_NAME}_LIB ${_LIB_SEARCH_NAME} PATHS ${${pkg_name}_BASE_DIR}/${lib_path} NO_DEFAULT_PATH) if(NOT ${_LIB_NAME}_LIB AND BUILD_LITE AND PLATFORM_ARM) set(${_LIB_NAME}_LIB "${${pkg_name}_BASE_DIR}/${lib_path}/lib${_LIB_SEARCH_NAME}.so") endif() if(NOT ${_LIB_NAME}_LIB) return() endif() add_library(${pkg_name}::${_LIB_NAME} ${_LIB_TYPE} IMPORTED GLOBAL) if(WIN32 AND ${_LIB_TYPE} STREQUAL "SHARED") set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES IMPORTED_IMPLIB_RELEASE ${${_LIB_NAME}_LIB}) else() set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES IMPORTED_LOCATION ${${_LIB_NAME}_LIB}) endif() if(EXISTS ${${pkg_name}_BASE_DIR}/include) set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${${pkg_name}_BASE_DIR}/include") endif() 
        list(APPEND ${pkg_name}_LIBS ${pkg_name}::${_LIB_NAME})
        message("found ${${_LIB_NAME}_LIB}")
        STRING(REGEX REPLACE "(.+)/(.+)" "\\1" LIBPATH ${${_LIB_NAME}_LIB})
        set(${pkg_name}_LIBPATH ${LIBPATH} CACHE STRING INTERNAL)
    endforeach()
    set(${pkg_name}_LIBS ${${pkg_name}_LIBS} PARENT_SCOPE)
endfunction()

function(__find_pkg_then_add_target_exe pkg_name lib_path)
    message("_FIND:${${pkg_name}_BASE_DIR}")
    foreach(pkg_exe ${ARGN})
        # find_program: looks up a program and creates a cache entry to store the result.
        # If the program is found, the result is stored in the variable and the search is not
        # repeated unless the variable is cleared; if nothing is found, the result is <var>-NOTFOUND.
        find_program(${pkg_exe}_EXE ${pkg_exe} PATHS ${${pkg_name}_BASE_DIR}/bin NO_DEFAULT_PATH)
        if(NOT ${pkg_exe}_EXE)
            return()
        endif()
        # add_executable with IMPORTED: introduces an executable target that references
        # an executable file located outside of the project.
        add_executable(${pkg_name}::${pkg_exe} IMPORTED GLOBAL)
        set_target_properties(${pkg_name}::${pkg_exe} PROPERTIES
                IMPORTED_LOCATION ${${pkg_exe}_EXE}
        )
        message("found ${${pkg_exe}_EXE}")
    endforeach()
endfunction()

function(__exec_cmd)
    set(options)
    set(oneValueArgs WORKING_DIRECTORY)
    set(multiValueArgs COMMAND)
    cmake_parse_arguments(EXEC "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    execute_process(COMMAND ${EXEC_COMMAND} WORKING_DIRECTORY ${EXEC_WORKING_DIRECTORY} RESULT_VARIABLE RESULT)
    if(NOT RESULT EQUAL "0")
        message(FATAL_ERROR "error! when ${EXEC_COMMAND} in ${EXEC_WORKING_DIRECTORY}")
    endif()
endfunction()

function(__check_patches pkg_patches)
    # check patches
    if(PKG_PATCHES)
        file(TOUCH ${_MS_LIB_CACHE}/${pkg_name}_patch.sha256)
        file(READ ${_MS_LIB_CACHE}/${pkg_name}_patch.sha256 ${pkg_name}_PATCHES_SHA256)
        message("patches sha256:${${pkg_name}_PATCHES_SHA256}")
        set(${pkg_name}_PATCHES_NEW_SHA256)
        foreach(_PATCH ${PKG_PATCHES})
            file(SHA256 ${_PATCH} _PF_SHA256)
            set(${pkg_name}_PATCHES_NEW_SHA256 "${${pkg_name}_PATCHES_NEW_SHA256},${_PF_SHA256}")
        endforeach()
        if(NOT ${pkg_name}_PATCHES_SHA256 STREQUAL ${pkg_name}_PATCHES_NEW_SHA256)
            set(${pkg_name}_PATCHES ${PKG_PATCHES})
            file(REMOVE_RECURSE "${_MS_LIB_CACHE}/${pkg_name}-subbuild")
            file(WRITE ${_MS_LIB_CACHE}/${pkg_name}_patch.sha256 ${${pkg_name}_PATCHES_NEW_SHA256})
            message("patches changed : ${${pkg_name}_PATCHES_NEW_SHA256}")
        endif()
    endif()
endfunction()

set(MS_FIND_NO_DEFAULT_PATH NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH NO_SYSTEM_ENVIRONMENT_PATH
        NO_CMAKE_BUILDS_PATH NO_CMAKE_PACKAGE_REGISTRY NO_CMAKE_SYSTEM_PATH
        NO_CMAKE_SYSTEM_PACKAGE_REGISTRY)
set(MS_FIND_NO_DEFAULT_PATH ${MS_FIND_NO_DEFAULT_PATH} PARENT_SCOPE)

function(mindspore_add_pkg pkg_name)
    message("---------add pkg: " ${pkg_name} "---------")
    set(options)
    set(oneValueArgs URL SHA256 GIT_REPOSITORY GIT_TAG VER DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH CUSTOM_CMAKE)
    set(multiValueArgs CMAKE_OPTION LIBS EXE PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION
            INSTALL_INCS INSTALL_LIBS PATCHES SUBMODULES SOURCEMODULES ONLY_MAKE ONLY_MAKE_INCS ONLY_MAKE_LIBS)
    cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    if(NOT PKG_LIB_PATH)
        set(PKG_LIB_PATH lib)
    endif()
    if(NOT PKG_EXE)
        set(PKG_EXE 0)
    endif()
    set(__FIND_PKG_NAME ${pkg_name})
    string(TOLOWER ${pkg_name} pkg_name)
    message("pkg name:${__FIND_PKG_NAME},${pkg_name}")
    set(${pkg_name}_PATCHES_HASH)
    foreach(_PATCH ${PKG_PATCHES})
        file(SHA256 ${_PATCH} _PF_SHA256)
        set(${pkg_name}_PATCHES_HASH "${${pkg_name}_PATCHES_HASH},${_PF_SHA256}")
    endforeach()
    # check options
    set(${pkg_name}_CONFIG_TXT
            "${CMAKE_CXX_COMPILER_VERSION}-${CMAKE_C_COMPILER_VERSION} ${ARGN} - ${${pkg_name}_USE_STATIC_LIBS}- ${${pkg_name}_PATCHES_HASH}
${${pkg_name}_CXXFLAGS}--${${pkg_name}_CFLAGS}--${${pkg_name}_LDFLAGS}") string(REPLACE ";" "-" ${pkg_name}_CONFIG_TXT ${${pkg_name}_CONFIG_TXT}) string(SHA256 ${pkg_name}_CONFIG_HASH ${${pkg_name}_CONFIG_TXT}) message("${pkg_name} config hash: ${${pkg_name}_CONFIG_HASH}") set(${pkg_name}_BASE_DIR ${_MS_LIB_CACHE}/${pkg_name}_${${pkg_name}_CONFIG_HASH}) set(${pkg_name}_DIRPATH ${${pkg_name}_BASE_DIR} CACHE STRING INTERNAL) if(EXISTS ${${pkg_name}_BASE_DIR}/options.txt AND PKG_HEAD_ONLY) set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE) add_library(${pkg_name} INTERFACE) target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC}) if(${PKG_RELEASE}) __find_pkg_then_add_target_exe(${pkg_name} ${PKG_LIB_PATH} ${PKG_EXE}) __find_pkg_then_add_target_lib(${pkg_name} ${PKG_LIB_PATH} ${PKG_LIBS}) endif() return() endif() set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR}) set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR} PARENT_SCOPE) if(PKG_LIBS) __find_pkg_then_add_target_exe(${pkg_name} ${PKG_LIB_PATH} ${PKG_EXE}) __find_pkg_then_add_target_lib(${pkg_name} ${PKG_LIB_PATH} ${PKG_LIBS}) if(${pkg_name}_LIBS) set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) message("Found libs: ${${pkg_name}_LIBS}") return() endif() elseif(NOT PKG_HEAD_ONLY) find_package(${__FIND_PKG_NAME} ${PKG_VER} ${MS_FIND_NO_DEFAULT_PATH}) if(${__FIND_PKG_NAME}_FOUND) set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) message("Found pkg: ${__FIND_PKG_NAME}") return() endif() endif() if(NOT PKG_DIR) if(PKG_GIT_REPOSITORY) __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_SHA256}) else() __download_pkg(${pkg_name} ${PKG_URL} ${PKG_SHA256}) endif() foreach(_SUBMODULE_FILE ${PKG_SUBMODULES}) STRING(REGEX REPLACE "(.+)_(.+)" "\\1" _SUBMODEPATH ${_SUBMODULE_FILE}) STRING(REGEX REPLACE "(.+)/(.+)" "\\2" _SUBMODENAME ${_SUBMODEPATH}) file(GLOB ${pkg_name}_INSTALL_SUBMODULE ${_SUBMODULE_FILE}/*) file(COPY ${${pkg_name}_INSTALL_SUBMODULE} DESTINATION ${${pkg_name}_SOURCE_DIR}/3rdparty/${_SUBMODENAME}) endforeach() else() set(${pkg_name}_SOURCE_DIR ${PKG_DIR}) endif() file(WRITE ${${pkg_name}_BASE_DIR}/options.txt ${${pkg_name}_CONFIG_TXT}) message("${pkg_name}_SOURCE_DIR : ${${pkg_name}_SOURCE_DIR}") foreach(_PATCH_FILE ${PKG_PATCHES}) get_filename_component(_PATCH_FILE_NAME ${_PATCH_FILE} NAME) set(_LF_PATCH_FILE ${CMAKE_BINARY_DIR}/_ms_patch/${_PATCH_FILE_NAME}) configure_file(${_PATCH_FILE} ${_LF_PATCH_FILE} NEWLINE_STYLE LF @ONLY) message("patching ${${pkg_name}_SOURCE_DIR} -p1 < ${_LF_PATCH_FILE}") execute_process(COMMAND ${Patch_EXECUTABLE} -p1 INPUT_FILE ${_LF_PATCH_FILE} WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR} RESULT_VARIABLE Result) if(NOT Result EQUAL "0") message(FATAL_ERROR "Failed patch: ${_LF_PATCH_FILE}") endif() endforeach() foreach(_SOURCE_DIR ${PKG_SOURCEMODULES}) file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*) file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/) endforeach() file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600) if(NOT ${pkg_name}_LOCK_RET EQUAL "0") message(FATAL_ERROR "error! 
when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}") endif() if(PKG_CUSTOM_CMAKE) file(GLOB ${pkg_name}_cmake ${PKG_CUSTOM_CMAKE}/CMakeLists.txt) file(COPY ${${pkg_name}_cmake} DESTINATION ${${pkg_name}_SOURCE_DIR}) endif() if(${pkg_name}_SOURCE_DIR) if(PKG_HEAD_ONLY) file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*) file(COPY ${${pkg_name}_SOURCE_SUBDIRS} DESTINATION ${${pkg_name}_BASE_DIR}) set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE) if(NOT PKG_RELEASE) add_library(${pkg_name} INTERFACE) target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC}) endif() elseif(PKG_ONLY_MAKE) __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_CXXFLAGS} -j${THNUM} WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) set(PKG_INSTALL_INCS ${PKG_ONLY_MAKE_INCS}) set(PKG_INSTALL_LIBS ${PKG_ONLY_MAKE_LIBS}) file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS}) file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS}) file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include) file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib) elseif(PKG_CMAKE_OPTION) # in cmake file(MAKE_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) if(${pkg_name}_CFLAGS) set(${pkg_name}_CMAKE_CFLAGS "-DCMAKE_C_FLAGS=${${pkg_name}_CFLAGS}") endif() if(${pkg_name}_CXXFLAGS) set(${pkg_name}_CMAKE_CXXFLAGS "-DCMAKE_CXX_FLAGS=${${pkg_name}_CXXFLAGS}") endif() if(${pkg_name}_LDFLAGS) if(${pkg_name}_USE_STATIC_LIBS) #set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_STATIC_LINKER_FLAGS=${${pkg_name}_LDFLAGS}") else() set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_SHARED_LINKER_FLAGS=${${pkg_name}_LDFLAGS}") endif() endif() __exec_cmd(COMMAND ${CMAKE_COMMAND} ${PKG_CMAKE_OPTION} -G ${CMAKE_GENERATOR} ${${pkg_name}_CMAKE_CFLAGS} ${${pkg_name}_CMAKE_CXXFLAGS} ${${pkg_name}_CMAKE_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ${${pkg_name}_SOURCE_DIR}/${PKG_CMAKE_PATH} WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . 
--target install -- -j${THNUM} WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) else() if(${pkg_name}_CFLAGS) set(${pkg_name}_MAKE_CFLAGS "CFLAGS=${${pkg_name}_CFLAGS}") endif() if(${pkg_name}_CXXFLAGS) set(${pkg_name}_MAKE_CXXFLAGS "CXXFLAGS=${${pkg_name}_CXXFLAGS}") endif() if(${pkg_name}_LDFLAGS) set(${pkg_name}_MAKE_LDFLAGS "LDFLAGS=${${pkg_name}_LDFLAGS}") endif() # in configure && make if(PKG_PRE_CONFIGURE_COMMAND) __exec_cmd(COMMAND ${PKG_PRE_CONFIGURE_COMMAND} WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) endif() if(PKG_CONFIGURE_COMMAND) __exec_cmd(COMMAND ${PKG_CONFIGURE_COMMAND} ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS} --prefix=${${pkg_name}_BASE_DIR} WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) endif() set(${pkg_name}_BUILD_OPTION ${PKG_BUILD_OPTION}) if(NOT PKG_CONFIGURE_COMMAND) set(${pkg_name}_BUILD_OPTION ${${pkg_name}_BUILD_OPTION} ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS}) endif() # build __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j${THNUM} WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) if(PKG_INSTALL_INCS OR PKG_INSTALL_LIBS) file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS}) file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS}) file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include) file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib) else() __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} install WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) endif() endif() endif() if(PKG_LIBS) __find_pkg_then_add_target_exe(${pkg_name} ${PKG_LIB_PATH} ${PKG_EXE}) __find_pkg_then_add_target_lib(${pkg_name} ${PKG_LIB_PATH} ${PKG_LIBS}) set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) if(NOT ${pkg_name}_LIBS) message(FATAL_ERROR "Can not find pkg: ${pkg_name}") endif() else() find_package(${__FIND_PKG_NAME} ${PKG_VER} QUIET ${MS_FIND_NO_DEFAULT_PATH}) if(${__FIND_PKG_NAME}_FOUND) set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) message("Found pkg: ${${__FIND_PKG_NAME}_LIBRARIES}") return() endif() endif() endfunction() ================================================ FILE: docs/api/api_python/client/mindspore_serving.client.Client.rst ================================================  .. py:class:: mindspore_serving.client.Client(address, servable_name, method_name, version_number=0, ssl_config=None) 通过Client访问Serving服务器gRPC接口,可用于创建请求、访问服务和解析结果。 .. note:: Serving客户端在一个请求中可以发送的最大数据量为512MB,Serving服务器可以接收的最大数据量可以配置为1~512MB,默认为100MB。 参数: - **address** (str) - Serving服务器gRPC接口地址。 - **servable_name** (str) - Serving服务器提供的服务的名称。 - **method_name** (str) - 服务中方法的名称。 - **version_number** (int, optional) - 服务的版本号,``0`` 表示指定所有正在运行的一个或多个版本的服务中最大的版本号。默认值:``0``。 - **ssl_config** (mindspore_serving.client.SSLConfig, optional) - SSL配置,如果 ``None``,则禁用SSL。默认值:``None``。 异常: - **RuntimeError** - 参数的类型或值无效,或发生其他错误。 .. py:method:: infer(instances) 用于创建请求、访问服务、解析和返回结果。 参数: - **instances** (Union[dict, tuple[dict]]) - 一个实例或一组实例的输入,每个实例都是dict。dict的key是输入名称,value是输入值。value的类型可以是Python int、float、bool、str、bytes、numpy scalar或numpy array对象。 异常: - **RuntimeError** - 参数的类型或值无效,或发生其他错误。 .. 
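A minimal usage sketch of the `Client.infer` call documented above; the address and the servable and method names ("add", "add_common") follow this repository's tensor_add example and are otherwise placeholders::

    import numpy as np
    from mindspore_serving.client import Client

    # connect to the gRPC endpoint of a running Serving server
    client = Client("127.0.0.1:5500", "add", "add_common")
    # each instance is a dict mapping input names to values; several instances can be sent at once
    instances = [{"x1": np.ones((2, 2), np.float32), "x2": np.ones((2, 2), np.float32)}]
    result = client.infer(instances)  # one result dict per instance
    print(result[0]["y"])

..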
py:method:: infer_async(instances) 用于创建请求,异步访问服务。 参数: - **instances** (Union[dict, tuple[dict]]) - 一个实例或一组实例的输入,每个实例都是dict。dict的key是输入名称,value是输入值。value的类型可以是Python int、float、bool、str、bytes、numpy scalar或numpy array对象。 异常: - **RuntimeError** - 参数的类型或值无效,或发生其他错误。 ================================================ FILE: docs/api/api_python/client/mindspore_serving.client.SSLConfig.rst ================================================  .. py:class:: mindspore_serving.client.SSLConfig(certificate=None, private_key=None, custom_ca=None) Serving服务器gRPC SSL使能时,通过SSLConfig封装SSL证书等相关参数。 参数: - **certificate** (str, 可选) - PEM编码的证书链内容,如果为 ``None``,表示不使用证书链。默认值:``None``。 - **private_key** (str, 可选) - PEM编码的私钥内容,如果为 ``None``,表示不使用私钥。默认值:``None``。 - **custom_ca** (str, 可选) - PEM编码的根证书内容,如果为 ``None``,gRPC运行时将从默认位置加载根证书。默认值:``None``。 异常: - **RuntimeError** - 参数的类型或值无效。 ================================================ FILE: docs/api/api_python/client/mindspore_serving.client.rst ================================================ MindSpore Serving客户端API,用于通过gRPC访问MindSpore Serving服务器。 ================================================ FILE: docs/api/api_python/mindspore_serving.client.rst ================================================ mindspore_serving.client ========================== .. include:: client/mindspore_serving.client.rst .. include:: client/mindspore_serving.client.Client.rst .. include:: client/mindspore_serving.client.SSLConfig.rst .. automodule:: mindspore_serving.client :members: ================================================ FILE: docs/api/api_python/mindspore_serving.server.rst ================================================ mindspore_serving.server ========================= .. include:: server/mindspore_serving.server.rst .. include:: server/mindspore_serving.server.start_grpc_server.rst .. include:: server/mindspore_serving.server.start_restful_server.rst .. include:: server/mindspore_serving.server.stop.rst .. include:: server/mindspore_serving.server.start_servables.rst .. include:: server/mindspore_serving.server.ServableStartConfig.rst .. include:: server/mindspore_serving.server.SSLConfig.rst .. automodule:: mindspore_serving.server :members: mindspore_serving.server.register ---------------------------------- .. include:: server/register/mindspore_serving.server.register.rst .. include:: server/register/mindspore_serving.server.register.declare_model.rst .. include:: server/register/mindspore_serving.server.register.Model.rst .. include:: server/register/mindspore_serving.server.register.AscendDeviceInfo.rst .. include:: server/register/mindspore_serving.server.register.CPUDeviceInfo.rst .. include:: server/register/mindspore_serving.server.register.GPUDeviceInfo.rst .. include:: server/register/mindspore_serving.server.register.Context.rst .. include:: server/register/mindspore_serving.server.register.register_method.rst .. include:: server/register/mindspore_serving.server.register.add_stage.rst .. automodule:: mindspore_serving.server.register :members: mindspore_serving.server.distributed ------------------------------------- .. include:: server/distributed/mindspore_serving.server.distributed.rst .. include:: server/distributed/mindspore_serving.server.distributed.start_servable.rst .. include:: server/distributed/mindspore_serving.server.distributed.startup_agents.rst .. include:: server/distributed/mindspore_serving.server.distributed.declare_servable.rst .. 
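A companion sketch for the asynchronous call `Client.infer_async` documented in the client API above; that the returned handle exposes a blocking `result()` accessor is an assumption of this sketch::

    import numpy as np
    from mindspore_serving.client import Client

    client = Client("127.0.0.1:5500", "add", "add_common")
    instance = {"x1": np.ones((2, 2), np.float32), "x2": np.ones((2, 2), np.float32)}
    handle = client.infer_async(instance)  # returns without waiting for the reply
    # ... other work can overlap with the inference here ...
    result = handle.result()  # assumed blocking accessor for the final result

..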
automodule:: mindspore_serving.server.distributed :members: ================================================ FILE: docs/api/api_python/server/distributed/mindspore_serving.server.distributed.declare_servable.rst ================================================  .. py:function:: mindspore_serving.server.distributed.declare_servable(rank_size, stage_size, with_batch_dim=True, without_batch_dim_inputs=None, enable_pipeline_infer=False) 用于在servable_config.py中声明分布式服务,详细可参考 `基于MindSpore Serving部署分布式推理服务 `_ 。 参数: - **rank_size** (int) - 分布式模型的rank大小。 - **stage_size** (int) - 分布式模型的stage大小。 - **with_batch_dim** (bool, 可选) - 模型输入和输出shape的第一个维度是否是batch维度。默认值:``True``。 - **without_batch_dim_inputs** (Union[int, tuple[int], list[int]], 可选) - 当 `with_batch_dim` 为 ``True`` 时,用于指定shape不包括batch维度的模型输入的索引,比如模型输入0的shape不包括batch维度,则 `without_batch_dim_inputs=(0,)` 。默认值:``None``。 - **enable_pipeline_infer** (bool, 可选) - 是否开启流水线并行推理,流水线并行可有效提升推理性能,详情可参考 `流水线并行 `_ 。默认值:``False``。 返回: `Model` ,此模型的标识,可以用来调用 `Model.call` 或作为 `add_stage` 的输入。 异常: - **RuntimeError** - 参数的类型或值无效。 ================================================ FILE: docs/api/api_python/server/distributed/mindspore_serving.server.distributed.rst ================================================ Serving服务器启动分布式模型服务的接口。如何配置和启动分布式模型,请查看 `基于MindSpore Serving部署分布式推理服务 `_ 。 ================================================ FILE: docs/api/api_python/server/distributed/mindspore_serving.server.distributed.start_servable.rst ================================================  .. py:function:: mindspore_serving.server.distributed.start_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, distributed_address='0.0.0.0:6200', wait_agents_time_in_seconds=0) 启动在 `servable_directory` 中定义的名为 `servable_name` 的分布式服务。 参数: - **servable_directory** (str) - 服务所在的目录。预期有一个名为 `servable_name` 的目录。详细信息可以查看 `通过配置模型提供Servable `_ 。 - **servable_name** (str) - 服务名称。 - **version_number** (int, 可选) - 要加载的服务版本号。版本号应为正整数,从1开始。默认值:``1``。 - **rank_table_json_file** (str) - rank table json文件名。 - **distributed_address** (str, 可选) - Worker代理(Agent)连接的分布式Worker服务器地址。默认值: ``"0.0.0.0:6200"`` 。 - **wait_agents_time_in_seconds** (int, 可选) - 等待所有Worker代理就绪的最长时间(以秒为单位),``0`` 表示无限时间。默认值:``0``。 异常: - **RuntimeError** - 启动分布式服务失败。 ================================================ FILE: docs/api/api_python/server/distributed/mindspore_serving.server.distributed.startup_agents.rst ================================================  .. 
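A sketch of bringing up the distributed servable with `start_servable` documented above; the servable name, rank table file and addresses loosely follow this repository's matmul_distributed example and are otherwise placeholders::

    import os
    from mindspore_serving import server
    from mindspore_serving.server import distributed

    servable_dir = os.path.dirname(os.path.realpath(__file__))
    # load the distributed servable and wait for the worker agents to attach
    distributed.start_servable(servable_directory=servable_dir, servable_name="matmul",
                               rank_table_json_file="rank_table_8pcs.json",
                               distributed_address="127.0.0.1:6200")
    server.start_grpc_server(address="127.0.0.1:5500")

..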
py:function:: mindspore_serving.server.distributed.startup_agents(distributed_address, model_files, group_config_files=None, agent_start_port=7000, agent_ip=None, rank_start=None, dec_key=None, dec_mode='AES-GCM')

    在当前计算机上启动所有所需的Worker代理(Agent),这组Worker代理进程将负责本机器设备上的推理任务,详细可参考 `基于MindSpore Serving部署分布式推理服务 `_ 。

    参数:
        - **distributed_address** (str) - Worker代理连接分布式Worker服务器地址。
        - **model_files** (Union[list[str], tuple[str]]) - 当前计算机中需要的所有模型文件,为绝对路径或相对于此启动Python脚本的路径。
        - **group_config_files** (Union[list[str], tuple[str]], 可选) - 当前计算机中需要的所有组配置文件,为绝对路径或相对于此启动Python脚本的相对路径,为 ``None`` 时表示没有配置文件。默认值:``None``。
        - **agent_start_port** (int, 可选) - Worker代理连接Worker服务器的起始端口号。默认值:``7000``。
        - **agent_ip** (str, 可选) - 本地Worker代理ip,如果为 ``None``,则代理ip将从rank table文件中获取。参数 `agent_ip` 和参数 `rank_start` 必须同时有值,或者同时是 ``None``。默认值:``None``。
        - **rank_start** (int, 可选) - 此计算机的起始rank id,如果为 ``None``,则将从rank table文件中获取rank id。参数 `agent_ip` 和参数 `rank_start` 必须同时有值,或者同时是 ``None``。默认值:``None``。
        - **dec_key** (bytes, 可选) - 用于解密的密钥,类型为字节。有效长度为16、24或32。默认值:``None``。
        - **dec_mode** (str, 可选) - 指定解密模式,在设置了 `dec_key` 时生效。值可为: ``'AES-GCM'`` 或 ``'AES-CBC'`` 。默认值: ``'AES-GCM'`` 。

    异常:
        - **RuntimeError** - 启动Worker代理失败。

================================================
FILE: docs/api/api_python/server/mindspore_serving.server.SSLConfig.rst
================================================

.. py:class:: mindspore_serving.server.SSLConfig(certificate, private_key, custom_ca=None, verify_client=False)

    Serving服务器中,使能gRPC或RESTful服务器SSL功能时,SSL的参数配置。

    参数:
        - **certificate** (str) - PEM编码的证书链内容,如果值为 ``None``,则表示不使用证书链。
        - **private_key** (str) - PEM编码的私钥内容,如果值为 ``None``,则表示不使用私钥。
        - **custom_ca** (str, 可选) - PEM编码的根证书内容。当 `verify_client` 为 ``True`` 时, `custom_ca` 必须指定。当 `verify_client` 为 ``False`` 时,将忽略此参数。默认值:``None``。
        - **verify_client** (bool, 可选) - 如果 `verify_client` 为 ``True``,则启用客户端服务器双向认证。如果为 ``False``,则仅启用客户端对服务器的单向认证。默认值:``False``。

    异常:
        - **RuntimeError** - 参数的类型或值无效。

================================================
FILE: docs/api/api_python/server/mindspore_serving.server.ServableStartConfig.rst
================================================
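A matching sketch for `startup_agents` documented above, run once on each machine that holds devices; the model file paths are placeholders in the spirit of the matmul_distributed example::

    from mindspore_serving.server import distributed

    # eight ranks on this machine, one sliced model file per rank (placeholder paths)
    model_files = ["model/device{}_matmul.mindir".format(i) for i in range(8)]
    distributed.startup_agents(distributed_address="127.0.0.1:6200", model_files=model_files)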
.. py:class:: mindspore_serving.server.ServableStartConfig(servable_directory, servable_name, device_ids=None, version_number=0, device_type=None, num_parallel_workers=0, dec_key=None, dec_mode='AES-GCM')

    启动一个服务的配置。详情请查看 `基于MindSpore Serving部署推理服务 `_ 和 `通过配置模型提供Servable `_ 。

    参数:
        - **servable_directory** (str) - 服务所在的目录。预期有一个名为 `servable_name` 的目录。
        - **servable_name** (str) - 服务名称。
        - **device_ids** (Union[int, list[int], tuple[int]], 可选) - 模型部署和运行的设备列表,列表中的每个设备将部署和运行一个服务副本。当设备类型为Nvidia GPU、Ascend 310P/910时使用。默认值:``None``。
        - **version_number** (int, 可选) - 要加载的服务的版本号。版本号应为正整数,从1开始,``0`` 表示加载最新版本。默认值:``0``。
        - **device_type** (str, 可选) - 模型部署的目标设备类型,目前支持 ``"Ascend"``、``"GPU"``、``"CPU"`` 和 ``None``。默认值:``None``。

          - ``"Ascend"``:目标设备为Ascend 310P/910等。
          - ``"GPU"``:目标设备为Nvidia GPU。
          - ``"CPU"``:目标设备为CPU。
          - ``None``:系统根据实际的后端设备和MindSpore推理包决定目标设备,推荐使用默认值 ``None``。

        - **num_parallel_workers** (int, 可选) - 处理Python任务的进程数,用于提高预处理、后处理等Python任务的处理能力。值小于 `device_ids` 的长度时,处理Python任务的进程数为 `device_ids` 的长度。值的范围为[0,64]。默认值:``0``。
        - **dec_key** (bytes, 可选) - 用于解密的字节类型密钥。有效长度为16、24或32。默认值:``None``。
        - **dec_mode** (str, 可选) - 指定解密模式,设置 `dec_key` 时生效。值可为: ``'AES-GCM'`` 或 ``'AES-CBC'`` 。默认值: ``'AES-GCM'`` 。

    异常:
        - **RuntimeError** - 参数的类型或值无效。

================================================
FILE: docs/api/api_python/server/mindspore_serving.server.rst
================================================
MindSpore Serving是一个轻量级、高性能的服务模块,旨在帮助MindSpore开发者在生产环境中高效部署在线推理服务。

用户可通过MindSpore Serving server API启动服务,启动gRPC和RESTful(HTTP)服务器。其中一个服务一般可由一个模型或者一组模型组合提供。客户端通过gRPC和RESTful(HTTP)服务器发送推理任务,接收推理结果。

================================================
FILE: docs/api/api_python/server/mindspore_serving.server.start_grpc_server.rst
================================================

.. py:function:: mindspore_serving.server.start_grpc_server(address, max_msg_mb_size=100, ssl_config=None)

    启动gRPC服务器,用于Serving客户端和Serving服务器之间的通信。

    参数:
        - **address** (str) - gRPC服务器地址,地址可以是 `{ip}:{port}` 或 `unix:{unix_domain_file_path}` 。

          - `{ip}:{port}` - Internet domain socket地址。
          - `unix:{unix_domain_file_path}` - Unix domain socket地址,用于与同一台计算机上的多个进程通信。 `{unix_domain_file_path}` 可以是相对路径或绝对路径,但文件所在的目录必须已经存在。

        - **max_msg_mb_size** (int, 可选) - 可接收的最大gRPC消息大小(MB),取值范围[1, 512]。默认值:``100``。
        - **ssl_config** (mindspore_serving.server.SSLConfig, 可选) - 服务器的SSL配置,如果 ``None``,则禁用SSL。默认值:``None``。

    异常:
        - **RuntimeError** - 启动gRPC服务器失败:参数校验失败,gRPC地址错误或端口重复。

================================================
FILE: docs/api/api_python/server/mindspore_serving.server.start_restful_server.rst
================================================

.. py:function:: mindspore_serving.server.start_restful_server(address, max_msg_mb_size=100, ssl_config=None)

    启动RESTful服务器,用于Serving客户端和Serving服务器之间的通信。

    参数:
        - **address** (str) - RESTful服务器地址,地址应为Internet domain socket地址。
        - **max_msg_mb_size** (int, 可选) - 最大可接收的RESTful消息大小,以MB为单位,取值范围[1, 512]。默认值:``100``。
        - **ssl_config** (mindspore_serving.server.SSLConfig, 可选) - 服务器的SSL配置,如果是 ``None``,则禁用SSL。默认值:``None``。

    异常:
        - **RuntimeError** - 启动RESTful服务器失败:参数校验失败,RESTful地址错误或端口重复。

================================================
FILE: docs/api/api_python/server/mindspore_serving.server.start_servables.rst
================================================
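A sketch of enabling SSL on the servers documented above; note that `SSLConfig` takes the PEM-encoded certificate *contents*, not file paths (the file names here are placeholders)::

    from mindspore_serving import server

    with open("server.crt") as f:
        certificate = f.read()
    with open("server.key") as f:
        private_key = f.read()
    ssl_config = server.SSLConfig(certificate=certificate, private_key=private_key)
    server.start_grpc_server(address="127.0.0.1:5500", max_msg_mb_size=200, ssl_config=ssl_config)
    server.start_restful_server(address="127.0.0.1:1500")

..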
py:function:: mindspore_serving.server.start_servables(servable_configs, enable_lite=False) 用于Serving服务器中启动一个或多个服务,一个模型可结合预处理、后处理提供一个服务,多个模型也可串接组合提供一个服务。 本接口可以用来启动多个不同的服务。一个服务可以部署在多个设备上,其中每个设备运行一个服务副本。 在Ascend 910硬件平台上,每个服务的每个副本都独占一个设备。不同的服务或同一服务的不同版本需要部署在不同的设备上。在Ascend 310P和GPU硬件平台上,一个设备可以被多个服务共享,不同服务或同一服务的不同版本可以部署在同一设备上,实现设备复用。 如何配置模型提供服务请查看 `基于MindSpore Serving部署推理服务 `_ 和 `通过配置模型提供Servable `_ 。 参数: - **servable_configs** (Union[ServableStartConfig, list[ServableStartConfig], tuple[ServableStartConfig]]) - 一个或多个服务的启动配置。 - **enable_lite** (bool) - 是否使用MindSpore Lite推理后端。 默认值:``False``。 异常: - **RuntimeError** - 启动一个或多个服务失败。相关日志可查看本Serving服务器启动脚本所在目录的子目录serving_logs。 ================================================ FILE: docs/api/api_python/server/mindspore_serving.server.stop.rst ================================================  .. py:function:: mindspore_serving.server.stop() 停止Serving服务器的运行。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.AscendDeviceInfo.rst ================================================  .. py:class:: mindspore_serving.server.register.AscendDeviceInfo(**kwargs) 用于设置Ascend设备配置。 参数: - **insert_op_cfg_path** (str, 可选) - AIPP配置文件的路径。 - **input_format** (str, 可选) - 模型输入格式,取值可以是 ``"ND"`` 、 ``"NCHW"`` 、 ``"NHWC"`` 、 ``"CHWN"`` 、 ``"NC1HWC0"`` 或 ``"NHWC1C0"`` 。 - **input_shape** (str, 可选) - 模型输入形状,如 ``"input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"`` 。 - **output_type** (str, 可选) - 模型输出类型,值可以是 ``"FP16"`` 、 ``"UINT8"`` 或 ``"FP32"`` ,默认值: ``"FP32"`` 。 - **precision_mode** (str, 可选) - 模型精度模式,取值可以是 ``"force_fp16"`` 、 ``"allow_fp32_to_fp16"`` 、 ``"must_keep_origin_dtype"`` 或者 ``"allow_mix_precision"`` 。默认值: ``"force_fp16"`` 。 - **op_select_impl_mode** (str, 可选) - 运算符选择模式,值可以是 ``"high_performance"`` 或 ``"high_precision"`` 。默认值: ``"high_performance"`` 。 - **fusion_switch_config_path** (str, 可选) - 融合配置文件路径,包括图融合和UB融合。系统内置图融合和UB融合规则,默认启用。您可以通过设置此参数禁用指定的融合规则。 - **buffer_optimize_mode** (str, 可选) - 数据缓存优化策略,值可以是 ``"l1_optimize"`` 、 ``"l2_optimize"`` 、 ``"off_optimize"`` 或者 ``"l1_and_l2_optimize"`` 。默认 ``"l2_optimize"`` 。 异常: - **RuntimeError** - Ascend设备配置无效。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.CPUDeviceInfo.rst ================================================  .. py:class:: mindspore_serving.server.register.CPUDeviceInfo(**kwargs) 用于CPU设备配置。 参数: - **precision_mode** (str, 可选) - 推理精度选项,值可以是 ``"origin"`` 或 ``"fp16"`` , ``"origin"`` 表示以模型中指定精度进行推理, ``"fp16"`` 表示以FP16精度进行推理。默认值: ``"origin"`` 。 异常: - **RuntimeError** - 选项无效,或值类型不是字符串。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.Context.rst ================================================  .. py:class:: mindspore_serving.server.register.Context(**kwargs) Context用于自定义设备配置,如果不指定Context,MindSpore Serving将使用默认设备配置。当使用推理后端为MindSpore Lite,且目标设备为Ascend或Nvidia GPU时,模型部分算子可能运行在CPU设备上,将额外配置 `CPUDeviceInfo` 。 参数: - **thread_num** (int, 可选) - 设置运行时的CPU线程数量,该选项仅当推理后端为MindSpore Lite有效。 - **thread_affinity_core_list** (tuple[int], list[int], 可选) - 设置运行时的CPU绑核列表,该选项仅当推理后端为MindSpore Lite有效。 - **enable_parallel** (bool, 可选) - 设置运行时是否支持并行,该选项仅当推理后端为MindSpore Lite有效。 异常: - **RuntimeError** - 输入参数的类型或值无效。 .. 
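A lifecycle sketch for `start_servables` and `stop` documented above, following the pattern of the example servers elsewhere in this repository; the servable name and addresses are placeholders::

    import os
    from mindspore_serving import server

    servable_dir = os.path.dirname(os.path.realpath(__file__))
    config = server.ServableStartConfig(servable_directory=servable_dir,
                                        servable_name="add", device_ids=(0, 1))
    server.start_servables(servable_configs=config)  # one replica per device id
    server.start_grpc_server(address="127.0.0.1:5500")
    # ... serve until shutdown is requested, then release all server resources ...
    server.stop()

..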
py:method:: append_device_info(device_info) 用于添加一个用户自定义的设备配置。 参数: - **device_info** (Union[CPUDeviceInfo, GPUDeviceInfo, AscendDeviceInfo]) - 用户自定义设备配置,用户不指定设备配置时将使用默认值。可以为每个可能的设备自定义设备配置,系统根据实际的后端设备和推理包选择所需的设备信息。 异常: - **RuntimeError** - 输入参数的类型或值无效。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.GPUDeviceInfo.rst ================================================  .. py:class:: mindspore_serving.server.register.GPUDeviceInfo(**kwargs) 用于GPU设备配置。 参数: - **precision_mode** (str, 可选) - 推理精度选项,值可以是 ``"origin"`` 或 ``"fp16"`` , ``"origin"`` 表示以模型中指定精度进行推理, ``"fp16"`` 表示以FP16精度进行推理。默认值: ``"origin"`` 。 异常: - **RuntimeError** - 选项无效,或值类型不是字符串。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.Model.rst ================================================  .. py:class:: mindspore_serving.server.register.Model(model_key) 用于表示一个声明的模型。用户不应该直接构造 `Model` 对象,而是来自于 `declare_model` 或 `declare_servable` 的返回。 参数: - **model_key** (str) - 模型的唯一标志。 .. py:method:: call(*args, subgraph=0) 调用模型推理接口。 参数: - **args** - 实例的元组/列表,或一个实例的输入。 - **subgraph** (int, 可选) - 子图索引,当一个模型中存在多个子图时使用。默认值:``0``。 返回: 当输入参数 `args` 为元组/列表时,返回为instances的元组,当前输入 `args` 为一个实例的输入时,输出为这个实例的输出。 异常: - **RuntimeError** - 输入无效。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.add_stage.rst ================================================  .. py:function:: mindspore_serving.server.register.add_stage(stage, *args, outputs_count, batch_size=None, tag=None) 在服务的 `servable_config.py` 中,通过 `register_method` 装饰(wrap)Python函数定义服务的一个方法(method),本接口用于定义这个方法中的一个运行步骤(stage),可以是一个Python函数或者模型。 .. note:: 入参 `args` 的长度应等于函数或模型的输入个数。 参数: - **stage** (Union(function, Model)) - 用户定义的Python函数或由 `declare_model` 返回 `Model` 对象。 - **outputs_count** (int) - 用户定义的Python函数或模型的输出个数。 - **batch_size** (int, 可选) - 仅当stage是Python函数,且函数一次可以处理多实例时,此参数有效。默认值:``None``。 - ``None``,函数的输入将是一个实例的输入。 - ``0``,函数的输入将是实例的元组对象,实例元组的最大长度由服务器根据模型的batch大小确定。 - int value >= 1,函数的输入将是实例的元组对象,实例元组的最大长度是 `batch_size` 指定的值。 - **args** - stage输入占位符,可以是 `register_method` 装饰(wrap)的函数的输入或其他 `add_stage` 的输出。 `args` 的长度应等于Python函数或模型的输入数量。 - **tag** (str, 可选) - stage的自定义标签,如 ``"preprocess"``,默认值:``None``。 异常: - **RuntimeError** - 参数的类型或值无效,或发生其他错误。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.declare_model.rst ================================================  .. py:function:: mindspore_serving.server.register.declare_model(model_file, model_format, with_batch_dim=True, options=None, without_batch_dim_inputs=None, context=None, config_file=None) 在服务的servable_config.py配置文件中使用,用于声明一个模型。 .. note:: 本接口需要在Serving服务器导入servable_config.py时生效。因此,建议在servable_config.py中全局使用此接口。 .. 
warning:: 参数 `options` 从1.6.0版本中已弃用,并将在未来版本中删除,请改用参数 `context` 。 参数: - **model_file** (Union[str, list[str]]) - 模型文件名。 - **model_format** (str) - 模型格式, ``"MindIR"`` 或 ``"MindIR_Lite"`` ,忽略大小写。 - **with_batch_dim** (bool, 可选) - 模型输入和输出的shape第一个维度是否是batch维度。默认值:``True``。 - **options** (Union[AclOptions, GpuOptions], 可选) - 模型的选项,支持 ``AclOptions`` 或 ``GpuOptions`` 。默认值:``None``。 - **context** (Context) - 用于配置设备环境的上下文信息,值为 ``None`` 时,Serving将依据部署的设备设置默认的设备上下文。默认值:``None``。 - **without_batch_dim_inputs** (Union[int, tuple[int], list[int]], 可选) - 当 `with_batch_dim` 为 ``True`` 时,用于指定shape不包括batch维度的模型输入的索引,比如模型输入0的shape不包括batch维度,则 `without_batch_dim_inputs` 可赋值为 `(0,)` 。默认值:``None``。 - **config_file** (str, 可选) - 用于设置混合精度推理的配置文件。文件路径可以是servable_config.py所在目录的绝对路径或相对路径。默认值:``None``。 返回: `Model` ,此模型的标识,可以用来调用 `Model.call` 或作为 `add_stage` 的输入。 异常: - **RuntimeError** - 参数的类型或值无效。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.register_method.rst ================================================  .. py:function:: mindspore_serving.server.register.register_method(output_names) 在服务的servable_config.py配置文件中使用,用于注册服务的方法,一个服务可以包括一个或多个方法,每个方法可基于模型提供不同的功能,客户端访问服务时需要指定服务和方法。MindSpore Serving支持由多个Python函数和多个模型组合串接提供服务。 .. note:: 本接口需要在Serving服务器导入servable_config.py时生效。因此,建议在servable_config.py中全局使用此接口。 此接口将定义方法的签名和处理流程。 签名包括方法名称、方法的输入和输出名称。当Serving客户端访问服务时,客户端需要指定服务名称、方法名称,并提供一个或多个推理实例。每个实例通过输入名称指定输入数据,并通过输出名称获取输出结果。 处理流程由一个或多个阶段(stage)组成,每个阶段可以是一个Python函数或模型。即,一个方法的处理流程可以包括一个或多个Python函数和一个或多个模型。此外,接口还定义了这些阶段之间的数据流。 参数: - **output_names** (Union[str, tuple[str], list[str]]) - 指定方法的输出名称。输入名称通过注册函数的参数名称指定。 异常: - **RuntimeError** - 参数的类型或值无效,或发生其他错误。 ================================================ FILE: docs/api/api_python/server/register/mindspore_serving.server.register.rst ================================================ 服务注册接口,在服务的servable_config.py配置文件中使用。如何配置servable_config.py文件,请查看 `通过配置模型提供Servable `_ 。 ================================================ FILE: engine/README.md ================================================ ## Overview An engine supports finetune and inference. ================================================ FILE: example/add_sub_pipeline/add_sub/servable_config.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ """add model servable config""" import numpy as np from mindspore_serving.server import register def add_trans_datatype(x1, x2): """define preprocess, this example has two inputs and two outputs""" return x1.astype(np.float32), x2.astype(np.float32) def add_1(x): return x + 1 # when with_batch_dim is set to False, only 2x2 add is supported # when with_batch_dim is set to True(default), Nx2 add is supported, while N is viewed as batch # float32 inputs/outputs add_model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) sub_model = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False) # register add_sub_only_model method in add_sub @register.register_method(output_names=["y"]) def add_sub_only_model(x1, x2, x3): # x1+x2-x3 """method add_sub_only_model data flow definition""" y = register.add_stage(add_model, x1, x2, outputs_count=1) y = register.add_stage(sub_model, y, x3, outputs_count=1) return y # register add_sub_complex method in add_sub @register.register_method(output_names=["y"]) def add_sub_complex(x1, x2, x3): # x1+x2+1-x3+1 """method add_sub_complex data flow definition""" x1, x2 = register.add_stage(add_trans_datatype, x1, x2, outputs_count=2) # cast input to float32 y = register.add_stage(add_model, x1, x2, outputs_count=1) y = register.add_stage(add_1, y, outputs_count=1) y = register.add_stage(sub_model, y, x3, outputs_count=1) y = register.add_stage(add_1, y, outputs_count=1) return y ================================================ FILE: example/add_sub_pipeline/export_model/add_sub_model.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================
"""add model generator"""
import os
from shutil import copyfile
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class AddNet(nn.Cell):
    """Define Net of add"""

    def __init__(self):
        super(AddNet, self).__init__()
        self.add = ops.Add()

    def construct(self, x_, y_):
        """construct add net"""
        return self.add(x_, y_)


class SubNet(nn.Cell):
    """Define Net of sub"""

    def __init__(self):
        super(SubNet, self).__init__()
        self.sub = ops.Sub()

    def construct(self, x_, y_):
        """construct sub net"""
        return self.sub(x_, y_)


def export_net():
    """Export add net of 2x2 + 2x2, and copy output model `tensor_add.mindir` and `tensor_sub.mindir`
    to directory ../add_sub/1"""
    x = np.ones([2, 2]).astype(np.float32)
    y = np.ones([2, 2]).astype(np.float32)
    add = AddNet()
    ms.export(add, ms.Tensor(x), ms.Tensor(y), file_name='tensor_add', file_format='MINDIR')
    sub = SubNet()
    ms.export(sub, ms.Tensor(x), ms.Tensor(y), file_name='tensor_sub', file_format='MINDIR')
    dst_dir = '../add_sub/1'
    try:
        os.mkdir(dst_dir)
    except OSError:
        pass
    dst_file = os.path.join(dst_dir, 'tensor_add.mindir')
    copyfile('tensor_add.mindir', dst_file)
    print("copy tensor_add.mindir to " + dst_dir + " success")
    dst_file = os.path.join(dst_dir, 'tensor_sub.mindir')
    copyfile('tensor_sub.mindir', dst_file)
    print("copy tensor_sub.mindir to " + dst_dir + " success")


if __name__ == "__main__":
    export_net()

================================================
FILE: example/add_sub_pipeline/serving_client.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
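# The export script above produces the directory layout that the Serving server
# expects for this example (a sketch; only the names used by this example are shown):
#
#   add_sub/                  <- servable_directory/servable_name
#       servable_config.py    <- model declarations and method definitions
#       1/                    <- version_number subdirectory
#           tensor_add.mindir
#           tensor_sub.mindir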
================================================
FILE: example/add_sub_pipeline/serving_client.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The client of example add_sub pipeline"""
import numpy as np
from mindspore_serving.client import Client


def is_float_equal(left, right):
    """Check whether two float arrays are element-wise equal within tolerance"""
    return (np.abs(left - right) < 0.0001).all()


def run_add_sub_only_model():
    """invoke servable add_sub method add_sub_only_model"""
    # x1+x2-x3
    client = Client("127.0.0.1:5500", "add_sub", "add_sub_only_model")
    instances = []

    # instance 1
    x1 = np.asarray([[30, 30], [20, 20]]).astype(np.float32)
    x2 = np.asarray([[20, 20], [20, 20]]).astype(np.float32)
    x3 = np.asarray([[10, 10], [10, 10]]).astype(np.float32)
    instances.append({"x1": x1, "x2": x2, "x3": x3})
    expect_y = x1 + x2 - x3

    result = client.infer(instances)
    print(result)
    assert len(result) == len(instances)
    assert is_float_equal(result[0]["y"], expect_y)


def run_add_sub_complex():
    """invoke servable add_sub method add_sub_complex"""
    # x1+x2+1-x3+1
    client = Client("127.0.0.1:5500", "add_sub", "add_sub_complex")
    instances = []

    # instance 1
    x1 = np.asarray([[30, 30], [20, 20]]).astype(np.float32)
    x2 = np.asarray([[20, 20], [20, 20]]).astype(np.float32)
    x3 = np.asarray([[10, 10], [10, 10]]).astype(np.float32)
    instances.append({"x1": x1, "x2": x2, "x3": x3})
    expect_y = x1 + x2 + 1 - x3 + 1

    result = client.infer(instances)
    print(result)
    assert len(result) == len(instances)
    assert is_float_equal(result[0]["y"], expect_y)


if __name__ == '__main__':
    run_add_sub_only_model()
    run_add_sub_complex()


================================================
FILE: example/add_sub_pipeline/serving_server.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The server of example add_sub pipeline"""
import os
from mindspore_serving import server


def start():
    servable_dir = os.path.dirname(os.path.realpath(__file__))

    servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="add_sub",
                                                 device_ids=(0, 1))
    server.start_servables(servable_configs=servable_config)

    server.start_grpc_server(address="127.0.0.1:5500")
    server.start_restful_server(address="127.0.0.1:1500")


if __name__ == "__main__":
    start()
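serving_server.py also starts a RESTful server on port 1500, though only the gRPC client is shown above. A sketch of the equivalent REST call, following the URL and payload layout used by the lenet RESTful client later in this extract (and assuming plain JSON arrays are accepted for the float32 tensor inputs):

import json
import requests


def run_add_sub_only_model_restful():
    """Sketch: POST one instance to method add_sub_only_model over REST."""
    x = [[1.0, 1.0], [1.0, 1.0]]
    payload = json.dumps({"instances": [{"x1": x, "x2": x, "x3": x}]})
    result = requests.post("http://127.0.0.1:1500/model/add_sub:add_sub_only_model", data=payload)
    print(result.text)  # x1+x2-x3 with all-ones inputs should yield an all-ones "y"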
================================================
FILE: example/lenet/export_model/export_lenet.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export Lenet for mnist dataset"""
import os
from shutil import copyfile
from lenet.export import export_lenet

if __name__ == '__main__':
    export_lenet()

    dst_dir = '../lenet/1'
    try:
        os.mkdir(dst_dir)
    except OSError:
        pass

    dst_file = os.path.join(dst_dir, 'lenet.mindir')
    copyfile('lenet.mindir', dst_file)


================================================
FILE: example/lenet/export_model/lenet/__init__.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lenet export model"""


================================================
FILE: example/lenet/export_model/lenet/export.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import os
import numpy as np
from easydict import EasyDict as ed

import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export

from .src.lenet import LeNet5

config = ed({
    'num_classes': 10,
    'batch_size': 2,
    'image_height': 32,
    'image_width': 32
})

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)


def export_lenet():
    """export lenet network"""
    network = LeNet5(config.num_classes)

    # load network checkpoint
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    ckpt_file = os.path.join(cur_dir, 'lenet_ascend_v111_offical_cv_mnist_bs32_acc98.ckpt')
    param_dict = load_checkpoint(ckpt_file)
    load_param_into_net(network, param_dict)

    # export network
    inputs = Tensor(np.ones([config.batch_size, 1, config.image_height, config.image_width]), mindspore.float32)
    export(network, inputs, file_name="lenet", file_format="MINDIR")


if __name__ == "__main__":
    export_lenet()
================================================
FILE: example/lenet/export_model/lenet/src/lenet.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
import mindspore.nn as nn
from mindspore.common.initializer import Normal


class LeNet5(nn.Cell):
    """
    Lenet network

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        """Construct lenet"""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
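A quick offline sanity check of the network's shape contract (a sketch; run from example/lenet/export_model/ so the lenet package is importable, with MindSpore installed):

import numpy as np
import mindspore as ms
from lenet.src.lenet import LeNet5

net = LeNet5(num_class=10)
x = ms.Tensor(np.zeros([1, 1, 32, 32], np.float32))
# two valid 5x5 convs with 2x2 pools reduce 32x32 to 5x5x16 before the fc stack
print(net(x).shape)  # expected: (1, 10)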
================================================
FILE: example/lenet/lenet/servable_config.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lenet config python file"""
from io import BytesIO
import numpy as np
from PIL import Image
from mindspore_serving.server import register


def preprocess_eager(image):
    """
    Define preprocess, input is image numpy, return preprocess result.
    Return type can be numpy, str, bytes, int, float, or bool.
    Use MindData Eager; this image processing can also use other image processing libraries,
    like numpy, PIL or cv2.
    """
    image = Image.open(BytesIO(image.tobytes())).convert('L').resize((32, 32), Image.ANTIALIAS)
    image = np.array(image, np.float32)
    image = image / 255.0
    return image


def postprocess_top1(score):
    """
    Define postprocess. This example has one input and one output.
    The input is the numpy tensor of the score, and the output is the index of the top-1 class.
    """
    max_idx = np.argmax(score)
    return max_idx


lenet_model = register.declare_model(model_file="lenet.mindir", model_format="MindIR")


@register.register_method(output_names=["label"])
def classify_top1(image):
    """Define method `classify_top1` for servable `lenet`.
    The input is `image` and the output is `label`."""
    x = register.add_stage(preprocess_eager, image, outputs_count=1)
    x = register.add_stage(lenet_model, x, outputs_count=1)
    x = register.add_stage(postprocess_top1, x, outputs_count=1)
    return x


================================================
FILE: example/lenet/serving_client.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Client for lenet"""
import os
from mindspore_serving.client import Client


def read_images():
    """Read images from directory test_image"""
    images_buffer = []
    image_files = []
    for path, _, file_list in os.walk("./test_image/"):
        for file_name in file_list:
            image_file = os.path.join(path, file_name)
            image_files.append(image_file)
    for image_file in image_files:
        with open(image_file, "rb") as fp:
            images_buffer.append(fp.read())
    return images_buffer, image_files


def run_classify_top1():
    """Client for servable lenet and method classify_top1"""
    print("run_classify_top1-----------")
    client = Client("localhost:5500", "lenet", "classify_top1")
    instances = []
    images_buffer, image_files = read_images()
    for image in images_buffer:
        instances.append({"image": image})
    result = client.infer(instances)
    print(result)
    for item, file in zip(result, image_files):
        print(f"file: {file}, result: {item['label']}")


def run_classify_top1_async():
    """Async client for servable lenet and method classify_top1"""
    print("run_classify_top1_async-----------")
    client = Client("localhost:5500", "lenet", "classify_top1")
    instances = []
    images_buffer, image_files = read_images()
    for image in images_buffer:
        instances.append({"image": image})
    result_future = client.infer_async(instances)
    result = result_future.result()
    print(result)
    for item, file in zip(result, image_files):
        print(f"file: {file}, result: {item['label']}")


def run_restful_classify_top1():
    """RESTful Client for servable lenet and method classify_top1"""
    print("run_restful_classify_top1-----------")
    import base64
    import requests
    import json
    instances = []
    images_buffer, image_files = read_images()
    for image in images_buffer:
        base64_data = base64.b64encode(image).decode()
        instances.append({"image": {"b64": base64_data}})

    instances_map = {"instances": instances}
    post_payload = json.dumps(instances_map)
    ip = "localhost"
    restful_port = 1500
    servable_name = "lenet"
    method_name = "classify_top1"
    result = requests.post(f"http://{ip}:{restful_port}/model/{servable_name}:{method_name}", data=post_payload)
    print(result.text)
    result = json.loads(result.text)
    for item, file in zip(result["instances"], image_files):
        print(f"file: {file}, result: {item['label']}")


if __name__ == '__main__':
    run_classify_top1()
    run_classify_top1_async()
    run_restful_classify_top1()
================================================
FILE: example/lenet/serving_server.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Start Servable lenet"""
import os
import sys
from mindspore_serving import server


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))

    config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="lenet", device_ids=(0, 1))
    server.start_servables(config)

    server.start_grpc_server("127.0.0.1:5500")
    server.start_restful_server("127.0.0.1:1500")


if __name__ == "__main__":
    start()


================================================
FILE: example/matmul_distributed/export_model/distributed_inference.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''distributed inference
The sample can be run on Ascend 910 AI processor.
'''
import numpy as np
from net import Net
from mindspore import context, Model, Tensor, export
from mindspore.communication import init


def test_inference():
    """distributed inference after distributed training"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    init(backend_name="hccl")
    context.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
                                      device_num=8, group_ckpt_save_file="./group_config.pb")

    predict_data = create_predict_data()
    network = Net(matmul_size=(96, 16))
    model = Model(network)
    model.infer_predict_layout(Tensor(predict_data))
    export(model.predict_network, Tensor(predict_data), file_name="matmul", file_format="MINDIR")


def create_predict_data():
    """user-defined predict data"""
    inputs_np = np.random.randn(128, 96).astype(np.float32)
    return Tensor(inputs_np)
================================================
FILE: example/matmul_distributed/export_model/export_model.sh
================================================
#!/bin/bash

EXEC_PATH=$(pwd)
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
export RANK_SIZE=8

rm -rf device*
for ((i = 1; i < ${RANK_SIZE}; i++)); do
    mkdir device$i
    cp *.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start inference for device $i"
    pytest -sv ./distributed_inference.py::test_inference >inference.log$i 2>&1 &
    cd ../
done

mkdir device0
cp *.py ./device0
cd ./device0
export DEVICE_ID=0
export RANK_ID=0
echo "start inference for device 0"
pytest -sv ./distributed_inference.py::test_inference >inference.log0 2>&1
if [ $? -eq 0 ]; then
    echo "inference success"
else
    echo "inference failed"
    cat inference.log0
    exit 2
fi
cd ../

ls device*/ -l
num=`ls device*/matmul.mindir -l | wc -l`
if [ ${num} -ne 8 ]; then
    echo "export matmul mindir failed"
    cat device0/inference.log0
    exit 2
fi

output_dir=../model
rm -rf ${output_dir}/device*
for ((i = 0; i < ${RANK_SIZE}; i++)); do
    mkdir -p ${output_dir}/device${i}
    cp device${i}/*.mindir ${output_dir}/device${i}/
    cp device${i}/*.pb ${output_dir}/device${i}/
done
echo "copy models success"


================================================
FILE: example/matmul_distributed/export_model/net.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''net
The sample can be run on Ascend 910 AI processor.
'''
import numpy as np
from mindspore import Tensor, Parameter, ops
from mindspore.nn import Cell


class Net(Cell):
    """Net"""

    def __init__(self, matmul_size, transpose_a=False, transpose_b=False, strategy=None):
        """init"""
        super().__init__()
        matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
        self.matmul_weight = Parameter(Tensor(matmul_np))
        self.matmul = ops.MatMul(transpose_a=transpose_a, transpose_b=transpose_b)
        self.neg = ops.Neg()
        if strategy is not None:
            self.matmul.shard(strategy)

    def construct(self, inputs):
        """construct"""
        x = self.matmul(inputs, self.matmul_weight)
        x = self.neg(x)
        return x
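Before bringing up agents and servers, the numeric contract of Net is easy to confirm with NumPy alone; a sketch mirroring construct() for matmul_size=(96, 16):

import numpy as np

w = np.full((96, 16), 0.5, dtype=np.float32)  # same init as Net.matmul_weight
x = np.ones((128, 96), np.float32)            # same shape as create_predict_data()
out = -(x @ w)                                # MatMul followed by Neg
print(out.shape, out[0, 0])                   # (128, 16) -48.0, since each entry is -(96 * 0.5)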
================================================
FILE: example/matmul_distributed/export_model/rank_table_8pcs.json
================================================
{
    "version": "1.0",
    "server_count": "1",
    "server_list": [
        {
            "server_id": "127.0.0.1",
            "device": [
                {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
                {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"},
                {"device_id": "2", "device_ip": "192.3.27.6", "rank_id": "2"},
                {"device_id": "3", "device_ip": "192.4.27.6", "rank_id": "3"},
                {"device_id": "4", "device_ip": "192.1.27.7", "rank_id": "4"},
                {"device_id": "5", "device_ip": "192.2.27.7", "rank_id": "5"},
                {"device_id": "6", "device_ip": "192.3.27.7", "rank_id": "6"},
                {"device_id": "7", "device_ip": "192.4.27.7", "rank_id": "7"}
            ],
            "host_nic_ip": "reserve"
        }
    ],
    "status": "completed"
}


================================================
FILE: example/matmul_distributed/matmul/servable_config.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Distributed matmul config python file"""
from mindspore_serving.server import distributed
from mindspore_serving.server import register

model = distributed.declare_servable(rank_size=8, stage_size=1, with_batch_dim=False)


@register.register_method(output_names=["y"])
def predict(x):
    y = register.add_stage(model, x, outputs_count=1)
    return y


================================================
FILE: example/matmul_distributed/rank_table_8pcs.json
================================================
{
    "version": "1.0",
    "server_count": "1",
    "server_list": [
        {
            "server_id": "127.0.0.1",
            "device": [
                {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
                {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"},
                {"device_id": "2", "device_ip": "192.3.27.6", "rank_id": "2"},
                {"device_id": "3", "device_ip": "192.4.27.6", "rank_id": "3"},
                {"device_id": "4", "device_ip": "192.1.27.7", "rank_id": "4"},
                {"device_id": "5", "device_ip": "192.2.27.7", "rank_id": "5"},
                {"device_id": "6", "device_ip": "192.3.27.7", "rank_id": "6"},
                {"device_id": "7", "device_ip": "192.4.27.7", "rank_id": "7"}
            ],
            "host_nic_ip": "reserve"
        }
    ],
    "status": "completed"
}


================================================
FILE: example/matmul_distributed/serving_agent.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Start agents of Distributed Servable matmul"""
from mindspore_serving.server import distributed


def start_agents():
    """Start all the agents in current machine"""
    model_files = []
    group_configs = []
    for i in range(8):
        model_files.append(f"model/device{i}/matmul.mindir")
        group_configs.append(f"model/device{i}/group_config.pb")

    distributed.startup_agents(distributed_address="127.0.0.1:6200", model_files=model_files,
                               group_config_files=group_configs)


if __name__ == '__main__':
    start_agents()
================================================
FILE: example/matmul_distributed/serving_client.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Client for distributed matmul"""
import numpy as np
from mindspore_serving.client import Client


def run_matmul():
    """Run client of distributed matmul"""
    client = Client("localhost:5500", "matmul", "predict")
    instance = {"x": np.ones((128, 96), np.float32)}
    result = client.infer(instance)
    print("result:\n", result)
    assert len(result) == 1
    assert "y" in result[0]


if __name__ == '__main__':
    run_matmul()


================================================
FILE: example/matmul_distributed/serving_server.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Start Distributed Servable matmul"""
import os
import sys
from mindspore_serving import server
from mindspore_serving.server import distributed


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))

    distributed.start_servable(servable_dir, "matmul", rank_table_json_file="rank_table_8pcs.json",
                               version_number=1, distributed_address="127.0.0.1:6200")

    server.start_grpc_server("127.0.0.1:5500")
    server.start_restful_server("127.0.0.1:1500")


if __name__ == "__main__":
    start()
================================================
FILE: example/matmul_multi_subgraphs/export_model/export_matmul.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''net
The sample can be run on Ascend 910 AI processor.
'''
import os
from shutil import copyfile
import numpy as np
import mindspore.context as context
from mindspore import Tensor, Parameter, ops, export
from mindspore.nn import Cell


class Net(Cell):
    """Net"""

    def __init__(self, matmul_size, init_val, transpose_a=False, transpose_b=False):
        """init"""
        super().__init__()
        matmul_np = np.full(matmul_size, init_val, dtype=np.float32)
        self.matmul_weight = Parameter(Tensor(matmul_np))
        self.matmul = ops.MatMul(transpose_a=transpose_a, transpose_b=transpose_b)
        self.sum = ops.ReduceSum()

    def construct(self, inputs):
        """construct"""
        x = self.matmul(inputs, self.matmul_weight)
        x = self.sum(x, 0)
        return x


def export_net():
    """Export matmul net, and copy output model `matmul_0.mindir` and `matmul_1.mindir` to directory ../matmul/1"""
    context.set_context(mode=context.GRAPH_MODE)
    network = Net(matmul_size=(96, 16), init_val=0.5)

    # subgraph 0: 128,96 matmul 96,16 -> 128,16 -> reduce sum axis 0 -> 16
    predict_data = np.random.randn(128, 96).astype(np.float32)
    # pylint: disable=protected-access
    export(network, Tensor(predict_data), file_name="matmul_0", file_format="MINDIR")

    # subgraph 1: 8,96 matmul 96,16 -> 8,16 -> reduce sum axis 0 -> 16
    predict_data = np.random.randn(8, 96).astype(np.float32)
    # pylint: disable=protected-access
    export(network, Tensor(predict_data), file_name="matmul_1", file_format="MINDIR")

    dst_dir = '../matmul/1'
    try:
        os.mkdir(dst_dir)
    except OSError:
        pass

    dst_file = os.path.join(dst_dir, 'matmul_0.mindir')
    copyfile('matmul_0.mindir', dst_file)
    print("copy matmul_0.mindir to " + dst_dir + " success")

    dst_file = os.path.join(dst_dir, 'matmul_1.mindir')
    copyfile('matmul_1.mindir', dst_file)
    print("copy matmul_1.mindir to " + dst_dir + " success")


if __name__ == "__main__":
    export_net()


================================================
FILE: example/matmul_multi_subgraphs/matmul/servable_config.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Multi-subgraph matmul config python file"""
from mindspore_serving.server import register

model = register.declare_model(model_file=["matmul_0.mindir", "matmul_1.mindir"], model_format="MindIR",
                               with_batch_dim=False)


def process(x, y):
    z1 = model.call(x, subgraph=0)  # 128,96 matmul 96,16 -> reduce sum axis 0 -> 16
    z2 = model.call(y, subgraph=1)  # 8,96 matmul 96,16 -> reduce sum axis 0 -> 16
    return z1 + z2


@register.register_method(output_names=["z"])
def predict(x, y):
    z = register.add_stage(process, x, y, outputs_count=1)
    return z
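The expected output of predict is fully determined by the all-ones inputs and the 0.5-initialized weight; a NumPy sketch of the arithmetic the two subgraphs perform:

import numpy as np

w = np.full((96, 16), 0.5, dtype=np.float32)
z1 = (np.ones((128, 96), np.float32) @ w).sum(axis=0)  # subgraph 0: every entry 128 * 48
z2 = (np.ones((8, 96), np.float32) @ w).sum(axis=0)    # subgraph 1: every entry 8 * 48
print((z1 + z2)[0])                                    # 6528.0, identical for all 16 entries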
================================================
FILE: example/matmul_multi_subgraphs/serving_client.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Client for multi-subgraph matmul"""
import numpy as np
from mindspore_serving.client import Client


def run_matmul():
    """Run client of multi-subgraph matmul"""
    client = Client("localhost:5500", "matmul", "predict")
    instance = {"x": np.ones((128, 96), np.float32), "y": np.ones((8, 96), np.float32)}
    result = client.infer(instance)
    print("result:\n", result)
    assert "z" in result[0]


if __name__ == '__main__':
    run_matmul()


================================================
FILE: example/matmul_multi_subgraphs/serving_server.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Start multi-subgraph Servable matmul"""
import os
import sys
from mindspore_serving import server


def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))

    servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="matmul",
                                                 device_ids=(0, 1))
    server.start_servables(servable_config)

    server.start_grpc_server("127.0.0.1:5500")
    server.start_restful_server("127.0.0.1:1500")


if __name__ == "__main__":
    start()
================================================
FILE: example/resnet/export_model/export_resnet.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export resnet50 for cifar10 dataset"""
import os
import sys
from shutil import copyfile
from resnet.export import export_resnet

if __name__ == '__main__':
    if len(sys.argv) > 1 and sys.argv[1] == 'False':  # python export_resnet.py False
        ckpt_file = None
    else:
        ckpt_file = "resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt"
        if not os.path.exists(ckpt_file):
            print("downloading resnet50 cifar10 checkpoint---------------------------------")
            os.system(f"wget https://download.mindspore.cn/model_zoo/r1.1/"
                      f"resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92/{ckpt_file} --no-check-certificate")
            print("end downloading resnet50 cifar10 checkpoint---------------------------------")

    export_resnet('resnet50_cifar10', ckpt_file, 'resnet50_1b_cifar10')

    dst_dir = '../resnet50/1'
    try:
        os.mkdir(dst_dir)
    except OSError:
        pass

    dst_file = os.path.join(dst_dir, 'resnet50_1b_cifar10.mindir')
    copyfile('resnet50_1b_cifar10.mindir', dst_file)


================================================
FILE: example/resnet/export_model/resnet/__init__.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Resnet export model"""
================================================
FILE: example/resnet/export_model/resnet/export.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into mindir models#################
python export.py
"""
import argparse
import numpy as np

from mindspore import Tensor, export
from mindspore import load_checkpoint, load_param_into_net


def export_resnet(network_dataset, ckpt_file, output_file):
    """export resnet"""
    if network_dataset == 'resnet50_cifar10':
        from .src.config import config1 as config
        from .src.resnet import resnet50 as resnet
    elif network_dataset == 'resnet50_imagenet2012':
        from .src.config import config2 as config
        from .src.resnet import resnet50 as resnet
    elif network_dataset == 'resnet101_imagenet2012':
        from .src.config import config3 as config
        from .src.resnet import resnet101 as resnet
    elif network_dataset == 'se-resnet50_imagenet2012':
        from .src.config import config4 as config
        from .src.resnet import se_resnet50 as resnet
    else:
        raise ValueError("network and dataset are not supported.")

    net = resnet(config.class_num)

    if ckpt_file is not None:
        param_dict = load_checkpoint(ckpt_file)
        load_param_into_net(net, param_dict)

    input_arr = Tensor(np.zeros([1, 3, 224, 224], np.float32))
    export(net, input_arr, file_name=output_file, file_format="MINDIR")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='resnet export')
    parser.add_argument('--network_dataset', type=str, default='resnet50_cifar10',
                        choices=['resnet50_cifar10', 'resnet50_imagenet2012', 'resnet101_imagenet2012',
                                 "se-resnet50_imagenet2012"],
                        help='network and dataset name.')
    parser.add_argument('--ckpt_file', type=str, default='', help='resnet ckpt file.')
    parser.add_argument('--output_file', type=str, default='', help='resnet output air name.')
    args_opt = parser.parse_args()
    export_resnet(args_opt.network_dataset, args_opt.ckpt_file, args_opt.output_file)
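Besides the CLI, export_resnet can also be driven programmatically, which is what export_resnet.py above does; a minimal sketch (run from example/resnet/export_model/ so the resnet package is importable):

from resnet.export import export_resnet

# ckpt_file=None skips load_checkpoint and exports randomly initialized weights,
# which is sufficient for a serving smoke test.
export_resnet('resnet50_cifar10', None, 'resnet50_1b_cifar10')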
================================================
FILE: example/resnet/export_model/resnet/src/config.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

# config for resnet50, cifar10
config1 = ed({
    "class_num": 10,
    "batch_size": 32,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 90,
    "pretrain_epoch_size": 0,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 5,
    "keep_checkpoint_max": 10,
    "save_checkpoint_path": "./",
    "warmup_epochs": 5,
    "lr_decay_mode": "poly",
    "lr_init": 0.01,
    "lr_end": 0.00001,
    "lr_max": 0.1
})

# config for resnet50, imagenet2012
config2 = ed({
    "class_num": 1001,
    "batch_size": 256,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 90,
    "pretrain_epoch_size": 0,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 5,
    "keep_checkpoint_max": 10,
    "save_checkpoint_path": "./",
    "warmup_epochs": 0,
    "lr_decay_mode": "linear",
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr_init": 0,
    "lr_max": 0.8,
    "lr_end": 0.0
})

# config for resnet101, imagenet2012
config3 = ed({
    "class_num": 1001,
    "batch_size": 32,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 120,
    "pretrain_epoch_size": 0,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 5,
    "keep_checkpoint_max": 10,
    "save_checkpoint_path": "./",
    "warmup_epochs": 0,
    "lr_decay_mode": "cosine",
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr": 0.1
})

# config for se-resnet50, imagenet2012
config4 = ed({
    "class_num": 1001,
    "batch_size": 32,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 28,
    "train_epoch_size": 24,
    "pretrain_epoch_size": 0,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 4,
    "keep_checkpoint_max": 10,
    "save_checkpoint_path": "./",
    "warmup_epochs": 3,
    "lr_decay_mode": "cosine",
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr_init": 0.0,
    "lr_max": 0.3,
    "lr_end": 0.0001
})
================================================
FILE: example/resnet/export_model/resnet/src/resnet.py
================================================
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
from scipy.stats import truncnorm
import mindspore as ms
from mindspore import nn
from mindspore import ops
from mindspore import Tensor


def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    fan_in = in_channel * kernel_size * kernel_size
    scale = 1.0
    scale /= max(1., fan_in)
    stddev = (scale ** 0.5) / .87962566103423978
    mu, sigma = 0, stddev
    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
    weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
    return Tensor(weight, dtype=ms.float32)


def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def calculate_gain(nonlinearity, param=None):
    """calculate_gain"""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res


def _calculate_fan_in_and_fan_out(tensor):
    """_calculate_fan_in_and_fan_out"""
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        fan_in = tensor[1]
        fan_out = tensor[0]
    else:
        num_input_fmaps = tensor[1]
        num_output_fmaps = tensor[0]
        receptive_field_size = 1
        if dimensions > 2:
            receptive_field_size = tensor[2] * tensor[3]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    else:
        weight_shape = (out_channel, in_channel, 3, 3)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if res_base:
        return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                         padding=1, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    else:
        weight_shape = (out_channel, in_channel, 1, 1)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if res_base:
        return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                         padding=0, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    else:
        weight_shape = (out_channel, in_channel, 7, 7)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if res_base:
        return nn.Conv2d(in_channel, out_channel, kernel_size=7, stride=stride,
                         padding=3, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel, kernel_size=7, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _bn(channel, res_base=False):
    if res_base:
        return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
                              gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, use_se=False):
    if use_se:
        weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
        weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=ms.float32)
    else:
        weight_shape = (out_channel, in_channel)
        weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block(bool): Use se block in SE-ResNet50 net. Default: False.

    Returns:
        Tensor, output tensor.
    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, use_se=False, se_block=False):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
        self.bn1 = _bn(channel)
        if self.use_se and self.stride != 1:
            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
                                         nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
        else:
            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
            self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
        self.bn3 = _bn_last(out_channel)
        if self.se_block:
            self.se_global_pool = ops.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
            self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = ops.Mul()
        self.relu = nn.ReLU()

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            if self.use_se:
                if stride == 1:
                    self.down_sample_layer = nn.SequentialCell(
                        [_conv1x1(in_channel, out_channel, stride, use_se=self.use_se), _bn(out_channel)])
                else:
                    self.down_sample_layer = nn.SequentialCell(
                        [nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
                         _conv1x1(in_channel, out_channel, 1, use_se=self.use_se), _bn(out_channel)])
            else:
                self.down_sample_layer = nn.SequentialCell(
                    [_conv1x1(in_channel, out_channel, stride, use_se=self.use_se), _bn(out_channel)])

    def construct(self, x):
        """Construct ResidualBlock"""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        if self.use_se and self.stride != 1:
            out = self.e2(out)
        else:
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.se_block:
            out_se = out
            out = self.se_global_pool(out, (2, 3))
            out = self.se_dense_0(out)
            out = self.relu(out)
            out = self.se_dense_1(out)
            out = self.se_sigmoid(out)
            out = out.reshape(out.shape + (1, 1))
            out = self.se_mul(out, out_se)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)
        return out


class ResidualBlockBase(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block(bool): Use se block in SE-ResNet50 net. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: True.

    Returns:
        Tensor, output tensor.
    Examples:
        >>> ResidualBlockBase(3, 256, stride=2)
    """

    # pylint: disable=unused-argument
    def __init__(self, in_channel, out_channel, stride=1, use_se=False, se_block=False, res_base=True):
        super(ResidualBlockBase, self).__init__()
        self.res_base = res_base
        self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
        self.bn1d = _bn(out_channel)
        self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
        self.bn2d = _bn(out_channel)
        self.relu = nn.ReLU()

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True

        self.down_sample_layer = None
        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell(
                [_conv1x1(in_channel, out_channel, stride, use_se=use_se, res_base=self.res_base),
                 _bn(out_channel, res_base)])

    def construct(self, x):
        """Construct ResidualBlockBase"""
        identity = x
        out = self.conv1(x)
        out = self.bn1d(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2d(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)
        return out


class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images are belonging to.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block(bool): Use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self, block, layer_nums, in_channels, out_channels, strides, num_classes,
                 use_se=False, res_base=False):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
        self.use_se = use_se
        self.res_base = res_base
        self.se_block = False
        if self.use_se:
            self.se_block = True

        if self.use_se:
            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
            self.bn1_0 = _bn(32)
            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
            self.bn1_1 = _bn(32)
            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
        else:
            self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
        self.bn1 = _bn(64, self.res_base)
        self.relu = ops.ReLU()

        if self.res_base:
            self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block, layer_nums[0], in_channel=in_channels[0],
                                       out_channel=out_channels[0], stride=strides[0], use_se=self.use_se)
        self.layer2 = self._make_layer(block, layer_nums[1], in_channel=in_channels[1],
                                       out_channel=out_channels[1], stride=strides[1], use_se=self.use_se)
        self.layer3 = self._make_layer(block, layer_nums[2], in_channel=in_channels[2],
                                       out_channel=out_channels[2], stride=strides[2], use_se=self.use_se,
                                       se_block=self.se_block)
        self.layer4 = self._make_layer(block, layer_nums[3], in_channel=in_channels[3],
                                       out_channel=out_channels[3], stride=strides[3], use_se=self.use_se,
                                       se_block=self.se_block)

        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            se_block(bool): Use se block in SE-ResNet50 net. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
        layers.append(resnet_block)
        if se_block:
            for _ in range(1, layer_num - 1):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
            resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
            layers.append(resnet_block)
        else:
            for _ in range(1, layer_num):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
        return nn.SequentialCell(layers)

    def construct(self, x):
        """Construct Resnet"""
        if self.use_se:
            x = self.conv1_0(x)
            x = self.bn1_0(x)
            x = self.relu(x)
            x = self.conv1_1(x)
            x = self.bn1_1(x)
            x = self.relu(x)
            x = self.conv1_2(x)
        else:
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.res_base:
            x = self.pad(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)
        return out


def resnet18(class_num=10):
    """
    Get ResNet18 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet18 neural network.

    Examples:
        >>> net = resnet18(10)
    """
    return ResNet(ResidualBlockBase, [2, 2, 2, 2], [64, 64, 128, 256], [64, 128, 256, 512], [1, 2, 2, 2],
                  class_num, res_base=True)


def resnet34(class_num=10):
    """
    Get ResNet34 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet34 neural network.

    Examples:
        >>> net = resnet34(10)
    """
    return ResNet(ResidualBlockBase, [3, 4, 6, 3], [64, 64, 128, 256], [64, 128, 256, 512], [1, 2, 2, 2],
                  class_num, res_base=True)


def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(ResidualBlock, [3, 4, 6, 3], [64, 256, 512, 1024], [256, 512, 1024, 2048], [1, 2, 2, 2],
                  class_num)


def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    return ResNet(ResidualBlock, [3, 4, 6, 3], [64, 256, 512, 1024], [256, 512, 1024, 2048], [1, 2, 2, 2],
                  class_num, use_se=True)


def resnet101(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    return ResNet(ResidualBlock, [3, 4, 23, 3], [64, 256, 512, 1024], [256, 512, 1024, 2048], [1, 2, 2, 2],
                  class_num)
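A shape sanity check for the factory functions above (a sketch; assumes a configured MindSpore device and that this module is importable as src.resnet, matching the relative imports in export.py):

import numpy as np
from mindspore import Tensor
from src.resnet import resnet50

net = resnet50(class_num=10)
x = Tensor(np.zeros([1, 3, 224, 224], np.float32))
print(net(x).shape)  # expected (1, 10), matching the export input shape in export.py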
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Resnet50 cifar10 config python file""" import numpy as np import mindspore.dataset as ds import mindspore.dataset.transforms.c_transforms as TC import mindspore.dataset.vision.c_transforms as VC from mindspore_serving.server import register # cifar 10 idx_2_label = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] def preprocess_eager(image): """ Define preprocess: the input is the image numpy data, and the return value is the preprocess result. The return type can be numpy, str, bytes, int, float, or bool. This uses MindData Eager mode; the image processing could also be done with other image processing libraries, like numpy, PIL or cv2. """ image_size = 224 mean = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255] std = [0.2023 * 255, 0.1994 * 255, 0.2010 * 255] decode = VC.Decode() resize = VC.Resize([image_size, image_size]) normalize = VC.Normalize(mean=mean, std=std) hwc2chw = VC.HWC2CHW() image = decode(image) image = resize(image) image = normalize(image) image = hwc2chw(image) return image def preprocess_batch(instances): """ Define the preprocess pipeline: the function arg is multiple instances, and every instance is a tuple of inputs. This example has one input and one output. This uses the MindData pipeline. """ def generator_func(): for instance in instances: image = instance[0] yield (image,) resnet_ds = ds.GeneratorDataset(generator_func, ["image"], shuffle=False) image_size = 224 mean = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255] std = [0.2023 * 255, 0.1994 * 255, 0.2010 * 255] resnet_ds = resnet_ds.map(operations=VC.Decode(), input_columns="image", num_parallel_workers=8) trans = [ VC.Resize([image_size, image_size]), VC.Normalize(mean=mean, std=std), VC.HWC2CHW() ] resnet_ds = resnet_ds.map(operations=TC.Compose(trans), input_columns="image", num_parallel_workers=2) for data in resnet_ds.create_dict_iterator(num_epochs=1): image_result = data["image"] yield (image_result,) def postprocess_top1(score): """ Define postprocess. This example has one input and one output. The input is the numpy tensor of the score, and the output is the label str of the top-1 class. """ max_idx = np.argmax(score) return idx_2_label[max_idx] def postprocess_top5(score): """ Define postprocess. This example has one input and two outputs. The input is the numpy tensor of the score. The first output is the str joined from the labels of the top five, and the second output is the score tensor of the top five.
""" idx = np.argsort(score)[::-1][:5] # top 5 ret_label = [idx_2_label[i] for i in idx] ret_score = score[idx] return ";".join(ret_label), ret_score resnet_model = register.declare_model(model_file="resnet50_1b_cifar10.mindir", model_format="MindIR") def call_resnet_model(image): """call model with only one instance a time""" image = preprocess_eager(image) score = resnet_model.call(image) # for only one instance return postprocess_top1(score) def call_resnet_model_batch(instances): """call model with multiply instances a time""" input_instances = [] for instance in instances: image = instance[0] # only one input image = preprocess_eager(image) # [3,224,224] input_instances.append([image]) output_instances = resnet_model.call(input_instances) # for multiply instances for instance in output_instances: output = instance[0] # only one output for each instance output = postprocess_top1(output) yield output @register.register_method(output_names=["label"]) def classify_top1_batch(image): """Define method `classify_top1` for servable `resnet50`. The input is `image` and the output is `lable`.""" x = register.add_stage(preprocess_batch, image, outputs_count=1, batch_size=1024) x = register.add_stage(resnet_model, x, outputs_count=1) x = register.add_stage(postprocess_top1, x, outputs_count=1) return x @register.register_method(output_names=["label"]) def classify_top1(image): # pipeline: preprocess_eager/postprocess_top1, model """Define method `classify_top1` for servable `resnet50`. The input is `image` and the output is `label`. """ x = register.add_stage(preprocess_eager, image, outputs_count=1) x = register.add_stage(resnet_model, x, outputs_count=1) x = register.add_stage(postprocess_top1, x, outputs_count=1) return x @register.register_method(output_names=["label"]) def classify_top1_v2(image): # without pipeline, call model with only one instance a time """Define method `classify_top1_v2` for servable `resnet50`. The input is `image` and the output is `label`. """ label = register.add_stage(call_resnet_model, image, outputs_count=1) return label @register.register_method(output_names=["label"]) def classify_top1_v3(image): # without pipeline, call model with maximum 32 instances a time """Define method `classify_top1_v2` for servable `resnet50`. The input is `image` and the output is `label`. """ label = register.add_stage(call_resnet_model_batch, image, outputs_count=1, batch_size=32) return label @register.register_method(output_names=["label", "score"]) def classify_top5(image): """Define method `classify_top5` for servable `resnet50`. The input is `image` and the output is `label` and `score`. """ x = register.add_stage(preprocess_eager, image, outputs_count=1) x = register.add_stage(resnet_model, x, outputs_count=1) label, score = register.add_stage(postprocess_top5, x, outputs_count=2) return label, score ================================================ FILE: example/resnet/serving_client.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Client for resnet50""" import os from mindspore_serving.client import Client def read_images(): """Read images from the directory test_image""" image_files = [] images_buffer = [] for path, _, file_list in os.walk("./test_image/"): for file_name in file_list: image_file = os.path.join(path, file_name) image_files.append(image_file) for image_file in image_files: with open(image_file, "rb") as fp: images_buffer.append(fp.read()) return image_files, images_buffer def run_classify_top1(method_name): """Client for servable resnet50 and methods classify_top1[_batch, _v2, _v3]""" print(f"\n--------------run_{method_name}----------") client = Client("localhost:5500", "resnet50", method_name) instances = [] image_files, images_buffer = read_images() for image in images_buffer: instances.append({"image": image}) result = client.infer(instances) print(result) for file, label in zip(image_files, result): print(f"{file}, label: {label['label']}") def run_classify_top5(): """Client for servable resnet50 and method classify_top5""" print("\n--------------run_classify_top5-----------") client = Client("localhost:5500", "resnet50", "classify_top5") instances = [] image_files, images_buffer = read_images() for image in images_buffer: instances.append({"image": image}) # input `image` result = client.infer(instances) print(result) for file, result_item in zip(image_files, result): # result for every image label = result_item["label"] # result `label` score = result_item["score"] # result `score` print("file:", file) print("label result:", label) print("score result:", score) def run_classify_top5_async(): """Client for servable resnet50 and method classify_top5, using the async API""" print("\n--------------run_classify_top5_async-----------") client = Client("localhost:5500", "resnet50", "classify_top5") instances = [] image_files, images_buffer = read_images() for image in images_buffer: instances.append({"image": image}) # input `image` result_future = client.infer_async(instances) result = result_future.result() print(result) for file, result_item in zip(image_files, result): # result for every image label = result_item["label"] # result `label` score = result_item["score"] # result `score` print("file:", file) print("label result:", label) print("score result:", score) def run_restful_classify_top1(): """RESTful Client for servable resnet50 and method classify_top1""" print("\n--------------run_restful_classify_top1-----------") import base64 import requests import json instances = [] image_files, images_buffer = read_images() for image in images_buffer: base64_data = base64.b64encode(image).decode() instances.append({"image": {"b64": base64_data}}) instances_map = {"instances": instances} post_payload = json.dumps(instances_map) ip = "localhost" restful_port = 1500 servable_name = "resnet50" method_name = "classify_top1" result = requests.post(f"http://{ip}:{restful_port}/model/{servable_name}:{method_name}", data=post_payload) print(result.text) result = json.loads(result.text) for file, label in zip(image_files, result['instances']): print(f"{file}, label: {label['label']}") if __name__ == '__main__': run_classify_top1("classify_top1_batch") run_classify_top1("classify_top1") # preprocess eager, pipeline run_classify_top1("classify_top1_v2") # preprocess eager, without pipeline run_classify_top1("classify_top1_v3") # preprocess eager, without pipeline
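# The calls above all follow the same request pattern; a minimal single-image sketch
# (illustrative only: the image path is hypothetical, while Client, infer and the
# servable/method names are the ones used in this file):
#     client = Client("localhost:5500", "resnet50", "classify_top1")
#     with open("./test_image/example.jpg", "rb") as fp:
#         result = client.infer([{"image": fp.read()}])
#     print(result[0]["label"])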
run_classify_top5() run_restful_classify_top1() run_classify_top5_async() ================================================ FILE: example/resnet/serving_server.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Start Servable resnet50""" import os import sys from mindspore_serving import server def start(): servable_dir = os.path.dirname(os.path.realpath(sys.argv[0])) # There are 4 workers in total. One worker occupies device 0; the model inference tasks of the other workers are # forwarded to the worker that occupies the device. config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="resnet50", device_ids=0, num_parallel_workers=4) server.start_servables(config) server.start_grpc_server("127.0.0.1:5500") server.start_restful_server("127.0.0.1:1500") if __name__ == "__main__": start() ================================================ FILE: example/tensor_add/add/servable_config.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# ============================================================================ """add model servable config""" import numpy as np from mindspore_serving.server import register def add_trans_datatype(x1, x2): """Define preprocess; this example has two inputs and two outputs""" return x1.astype(np.float32), x2.astype(np.float32) # when with_batch_dim is set to False, only 2x2 add is supported # when with_batch_dim is set to True (default), Nx2 add is supported, where N is viewed as the batch dimension # float32 inputs/outputs model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) # register add_common method in add @register.register_method(output_names=["y"]) def add_common(x1, x2): # only support float32 inputs """method add_common data flow definition, only calling the model""" y = register.add_stage(model, x1, x2, outputs_count=1) return y # register add_cast method in add @register.register_method(output_names=["y"]) def add_cast(x1, x2): """method add_cast data flow definition, preprocessing then calling the model""" x1, x2 = register.add_stage(add_trans_datatype, x1, x2, outputs_count=2) # cast input to float32 y = register.add_stage(model, x1, x2, outputs_count=1) return y ================================================ FILE: example/tensor_add/export_model/add_model.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """add model generator""" import os from shutil import copyfile import numpy as np import mindspore.context as context import mindspore.nn as nn import mindspore.ops as ops import mindspore as ms context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): """Define Net of add""" def __init__(self): super(Net, self).__init__() self.add = ops.Add() def construct(self, x_, y_): """construct add net""" return self.add(x_, y_) def export_net(): """Export add net of 2x2 + 2x2, and copy output model `tensor_add.mindir` to directory ../add/1""" x = np.ones([2, 2]).astype(np.float32) y = np.ones([2, 2]).astype(np.float32) add = Net() ms.export(add, ms.Tensor(x), ms.Tensor(y), file_name='tensor_add', file_format='MINDIR') dst_dir = '../add/1' try: os.mkdir(dst_dir) except OSError: pass dst_file = os.path.join(dst_dir, 'tensor_add.mindir') copyfile('tensor_add.mindir', dst_file) print("copy tensor_add.mindir to " + dst_dir + " success") if __name__ == "__main__": export_net() ================================================ FILE: example/tensor_add/serving_client.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """The client of example add""" import numpy as np from mindspore_serving.client import Client def run_add_common(): """invoke servable add method add_common""" client = Client("127.0.0.1:5500", "add", "add_common") instances = [] # instance 1 x1 = np.asarray([[1, 1], [1, 1]]).astype(np.float32) x2 = np.asarray([[1, 1], [1, 1]]).astype(np.float32) instances.append({"x1": x1, "x2": x2}) # instance 2 x1 = np.asarray([[2, 2], [2, 2]]).astype(np.float32) x2 = np.asarray([[2, 2], [2, 2]]).astype(np.float32) instances.append({"x1": x1, "x2": x2}) # instance 3 x1 = np.asarray([[3, 3], [3, 3]]).astype(np.float32) x2 = np.asarray([[3, 3], [3, 3]]).astype(np.float32) instances.append({"x1": x1, "x2": x2}) result = client.infer(instances) print(result) def run_add_cast(): """invoke servable add method add_cast""" client = Client("127.0.0.1:5500", "add", "add_cast") instances = [] x1 = np.ones((2, 2), np.int32) x2 = np.ones((2, 2), np.int32) instances.append({"x1": x1, "x2": x2}) result = client.infer(instances) print(result) def post_restful(address, servable_name, method_name, json_instances, version_number=None): """construct and post restful request""" import json import requests instances_map = {"instances": json_instances} post_payload = json.dumps(instances_map) print("request:", post_payload[:200]) if version_number is not None: request_url = f"http://{address}/model/{servable_name}/version/{version_number}:{method_name}" result = requests.post(request_url, data=post_payload) else: request_url = f"http://{address}/model/{servable_name}:{method_name}" result = requests.post(request_url, data=post_payload) print("result", result.text[:200]) result = json.loads(result.text) return result def run_add_restful(): """run restful request: invoke servable add method add_common""" # Client print("begin to run add restful.") instances = [] x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) instances.append({"x1": x1.tolist(), "x2": x2.tolist()}) result = post_restful("localhost:1500", "add", "add_common", instances) print(result) if __name__ == '__main__': run_add_common() run_add_cast() run_add_restful() ================================================ FILE: example/tensor_add/serving_client_with_check.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ """The client of example add with result check""" import json import requests import numpy as np from mindspore_serving.client import Client def check_result(result, y_data_list): """check grpc output result""" assert len(result) == len(y_data_list) for result_item, y_data in zip(result, y_data_list): assert (np.abs(result_item["y"] - y_data) < 0.00001).all() def run_add_common(): """invoke servable add method add_common""" client = Client("localhost:5500", "add", "add_common") instances = [] instance_count = 3 y_data_list = [] for i in range(instance_count): x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) y_data_list.append(x1 + x2) instances.append({"x1": x1, "x2": x2}) result = client.infer(instances) print(result) check_result(result, y_data_list) def run_add_cast(): """invoke servable add method add_cast""" client = Client("localhost:5500", "add", "add_cast") instances = [] y_data_list = [] x1 = np.ones((2, 2), np.int32) x2 = np.ones((2, 2), np.int32) instances.append({"x1": x1, "x2": x2}) y_data_list.append((x1 + x2).astype(np.float32)) result = client.infer(instances) print(result) check_result(result, y_data_list) def post_restful(address, servable_name, method_name, json_instances, version_number=None): """construct post restful request""" instances_map = {"instances": json_instances} post_payload = json.dumps(instances_map) print("request:", post_payload[:200]) if version_number is not None: request_url = f"http://{address}/model/{servable_name}/version/{version_number}:{method_name}" result = requests.post(request_url, data=post_payload) else: request_url = f"http://{address}/model/{servable_name}:{method_name}" result = requests.post(request_url, data=post_payload) print("result", result.text[:200]) result = json.loads(result.text) return result def check_number_result(result, y_data_list, output_name="y"): """check restful output result""" result = result["instances"] assert len(result) == len(y_data_list) for result_item, expected_item in zip(result, y_data_list): result_item = np.array(result_item[output_name]) print("result", result_item) print("expect:", expected_item) assert result_item.shape == expected_item.shape assert (np.abs(result_item - expected_item) < 0.001).all() def run_add_restful(): """run restful request: invoke servable add method add_common""" # Client print("begin to run add restful.") y_data_list = [] instances = [] x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) y_data_list.append((x1 + x2).astype(np.float32)) instances.append({"x1": x1.tolist(), "x2": x2.tolist()}) result = post_restful("localhost:1500", "add", "add_common", instances) check_number_result(result, y_data_list) if __name__ == '__main__': run_add_common() run_add_cast() run_add_restful() ================================================ FILE: example/tensor_add/serving_server.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """The server of example add""" import os import sys from mindspore_serving import server def start(): servable_dir = os.path.dirname(os.path.realpath(sys.argv[0])) servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="add", device_ids=(0, 1)) server.start_servables(servable_configs=servable_config) server.start_grpc_server(address="127.0.0.1:5500") server.start_restful_server(address="127.0.0.1:1500") if __name__ == "__main__": start() ================================================ FILE: mindspore_serving/CMakeLists.txt ================================================ # This branch assumes that gRPC and all its dependencies are already installed # on this system, so they can be located by find_package(). # Find Protobuf installation # Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib -Wl,--no-as-needed") if(ENABLE_COVERAGE) add_compile_options(-coverage) add_link_options(-lgcov --coverage) endif() # Proto file # Generated sources file(GLOB_RECURSE PROTO_FILE_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ./proto/*.proto) ms_grpc_generate(PROTO_SRC_LIST PROTO_HDR_LIST ${PROTO_FILE_LIST}) add_library(PROTO_SRC_LIB STATIC ${PROTO_SRC_LIST}) target_compile_options(PROTO_SRC_LIB PRIVATE "-Wno-array-bounds") include_directories("${CMAKE_BINARY_DIR}/mindspore_serving" ${CMAKE_BINARY_DIR}) # for proto header file include_directories("ccsrc") # serving_common for c++ server and python interface file(GLOB_RECURSE SERVING_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ccsrc/master/*.cc" "ccsrc/common/*.cc" "ccsrc/worker/*.cc") file(GLOB_RECURSE SERVING_ASCEND_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ccsrc/worker/inference/mindspore_model_wrap.cc") list(REMOVE_ITEM SERVING_SRC ${SERVING_ASCEND_SRC}) add_library(serving_common SHARED ${SERVING_SRC}) add_library(serving_ascend SHARED ${SERVING_ASCEND_SRC}) target_link_libraries(serving_ascend PRIVATE serving_common) target_link_libraries(serving_ascend PRIVATE ${SECUREC_LIBRARY}) include(CheckPIESupported) check_pie_supported() set_property(TARGET serving_common PROPERTY POSITION_INDEPENDENT_CODE TRUE) set_property(TARGET serving_ascend PROPERTY POSITION_INDEPENDENT_CODE TRUE) target_link_libraries(serving_common PRIVATE PROTO_SRC_LIB) target_link_libraries(serving_common PRIVATE mindspore_serving::ssl mindspore_serving::crypto) target_link_libraries(serving_common PRIVATE mindspore_serving::grpc++) target_link_libraries(serving_common PRIVATE mindspore_serving::protobuf pthread rt) target_link_libraries(serving_common PRIVATE mindspore_serving::event mindspore_serving::event_pthreads) target_link_libraries(serving_common PRIVATE mindspore_serving::event_core) target_link_libraries(serving_common PRIVATE mindspore_serving::event_openssl) target_link_libraries(serving_common PRIVATE mindspore_serving::glog) target_link_libraries(serving_common PRIVATE mindspore_serving::eigen) target_link_libraries(serving_common PRIVATE ${SECUREC_LIBRARY}) 
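# A configure-time sketch of how the options referenced in this file are typically
# passed (illustrative only; the path and values below are assumptions, not
# repository defaults):
#   cmake -DENABLE_COVERAGE=ON -DMS_WHL_LIB_PATH=/path/to/mindspore/lib <source_dir>
# At least one of ENABLE_TESTCASES, MS_WHL_LIB_PATH or MS_BACKEND_HEADER must be
# set; otherwise the branch at the end of this file stops with a FATAL_ERROR.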
set_target_properties(serving_common PROPERTIES SKIP_BUILD_RPATH TRUE) # python add_compile_definitions(ENABLE_PYTHON) file(GLOB_RECURSE PY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ccsrc/python/*.cc") find_package(Python3 3.7 COMPONENTS Interpreter Development) if(Python3_FOUND) set(PYTHON_INCLUDE_DIRS "${Python3_INCLUDE_DIRS}") set(PYTHON_LIBRARIES "${Python3_LIBRARIES}") else() find_python_package(py_inc py_lib) set(PYTHON_INCLUDE_DIRS "${py_inc}") set(PYTHON_LIBRARIES "${py_lib}") endif() include_directories(${PYTHON_INCLUDE_DIRS}) pybind11_add_module(_mindspore_serving NO_EXTRAS ${PY_SRC_LIST}) set_target_properties(_mindspore_serving PROPERTIES LINK_FLAGS_RELEASE -s) target_link_libraries(_mindspore_serving PRIVATE "${PYTHON_LIBRARIES}") target_include_directories(_mindspore_serving PRIVATE ${pybind11_INCLUDE_DIRS}) target_link_libraries(_mindspore_serving PRIVATE serving_common) set_property(TARGET _mindspore_serving PROPERTY POSITION_INDEPENDENT_CODE TRUE) target_link_options(serving_common PRIVATE -Wl,-init,mindspore_serving_log_init) # user set path if(ENABLE_TESTCASES) include_directories(${CMAKE_SOURCE_DIR}/tests/ut/stub) target_link_libraries(serving_ascend PRIVATE mindspore) elseif(MS_WHL_LIB_PATH) include_directories(${MS_WHL_LIB_PATH}/../) elseif(MS_BACKEND_HEADER) include_directories(${CMAKE_SOURCE_DIR}/third_party/mindspore) include_directories(${CMAKE_SOURCE_DIR}/third_party/mindspore/mindspore/core) else() message(FATAL_ERROR "Please check MindSpore path.") endif() ================================================ FILE: mindspore_serving/__init__.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """MindSpore Serving.""" ================================================ FILE: mindspore_serving/ccsrc/common/buffer_tensor.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "common/buffer_tensor.h" namespace mindspore::serving { BufferTensor::BufferTensor(DataType type, const std::vector &shape, uint8_t *data, size_t data_len, bool data_readonly) { type_ = type; shape_ = shape; data_ = data; data_len_ = data_len; data_readonly_ = data_readonly; } BufferTensor::~BufferTensor() { data_ = nullptr; } std::vector BufferTensor::shape() const { return shape_; } void BufferTensor::set_shape(const std::vector &shape) { shape_ = shape; } DataType BufferTensor::data_type() const { return type_; } void BufferTensor::set_data_type(DataType type) { type_ = type; } const uint8_t *BufferTensor::data() const { return data_; } size_t BufferTensor::data_size() const { return data_len_; } bool BufferTensor::resize_data(size_t data_len) { if (data_len != data_len_) { MSI_LOG_EXCEPTION << "Buffer tensor cannot resize data"; } return true; } uint8_t *BufferTensor::mutable_data() { if (data_readonly_) { MSI_LOG_EXCEPTION << "Buffer tensor is create readonly"; } return data_; } size_t BufferTensor::bytes_data_size() const { if (!is_bytes_val_data()) { return 0; } return 1; } void BufferTensor::get_bytes_data(size_t index, const uint8_t **data, size_t *bytes_len) const { MSI_EXCEPTION_IF_NULL(data); MSI_EXCEPTION_IF_NULL(bytes_len); if (!is_bytes_val_data()) { MSI_LOG_EXCEPTION << "Buffer tensor data type is not kMSI_Bytes or kMSI_String, cannot get bytes data"; } *data = data_; *bytes_len = data_len_; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/common/buffer_tensor.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_BUFFER_TENSOR_H #define MINDSPORE_SERVING_BUFFER_TENSOR_H #include #include "common/serving_common.h" namespace mindspore::serving { class BufferTensor : public TensorBase { public: // the data's lifetime must longer than this object BufferTensor(DataType type, const std::vector &shape, uint8_t *data, size_t data_len, bool data_readonly); ~BufferTensor(); // For all data type std::vector shape() const override; void set_shape(const std::vector &shape) override; DataType data_type() const override; void set_data_type(DataType type) override; // All the following interfaces are not for kMSI_String and kMSI_Bytes const uint8_t *data() const override; size_t data_size() const override; bool resize_data(size_t data_len) override; uint8_t *mutable_data() override; // For kMSI_String and kMSI_Bytes void clear_bytes_data() override { MSI_LOG_EXCEPTION << "Buffer tensor cannot clear bytes data"; } void add_bytes_data(const uint8_t *, size_t) override { MSI_LOG_EXCEPTION << "Buffer tensor cannot add bytes data"; } size_t bytes_data_size() const override; void get_bytes_data(size_t index, const uint8_t **data, size_t *bytes_len) const override; private: uint8_t *data_ = nullptr; size_t data_len_ = 0; std::vector shape_; DataType type_; bool data_readonly_ = false; }; class BufferTensorWithOwner : public BufferTensor { public: BufferTensorWithOwner(const TensorBasePtr &buffer_tensor_owner, DataType type, const std::vector &shape, uint8_t *data, size_t data_len, bool data_readonly) : BufferTensor(type, shape, data, data_len, data_readonly), buffer_tensor_owner_(buffer_tensor_owner) {} ~BufferTensorWithOwner() = default; private: TensorBasePtr buffer_tensor_owner_; }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_BUFFER_TENSOR_H ================================================ FILE: mindspore_serving/ccsrc/common/exit_handle.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "common/exit_handle.h" #include #include namespace mindspore { namespace serving { ExitSignalHandle &ExitSignalHandle::Instance() { static ExitSignalHandle instance = ExitSignalHandle(); return instance; } void ExitSignalHandle::InitSignalHandle() { if (!has_inited_.test_and_set()) { (void)signal(SIGINT, HandleSignal); (void)signal(SIGTERM, HandleSignal); } } // waiting ctrl+c or stop message to exit, // if no server is running or server has exited, there is no need to wait void ExitSignalHandle::MasterWait() { if (!is_running_) { MSI_LOG_INFO << "Exit Handle has not started or has exited"; return; } auto exit_future = master_exit_requested_.get_future(); exit_future.wait(); MSI_LOG_WARNING << "Receive exit signal " << exit_signal_; } // waiting ctrl+c or stop message to exit, // if no server is running or server has exited, there is no need to wait void ExitSignalHandle::WorkerWait() { if (!is_running_) { MSI_LOG_INFO << "Exit Handle has not started or has exited"; return; } auto exit_future = worker_exit_requested_.get_future(); exit_future.wait(); MSI_LOG_WARNING << "Receive exit signal " << exit_signal_; } // waiting ctrl+c or stop message to exit, // if no server is running or server has exited, there is no need to wait void ExitSignalHandle::AgentWait() { if (!is_running_) { MSI_LOG_INFO << "Exit Handle has not started or has exited"; return; } auto exit_future = agent_exit_requested_.get_future(); exit_future.wait(); MSI_LOG_WARNING << "Receive exit signal " << exit_signal_; } void ExitSignalHandle::Start() { if (is_running_) { return; } is_running_ = true; master_exit_requested_ = std::promise(); worker_exit_requested_ = std::promise(); agent_exit_requested_ = std::promise(); has_exited_.clear(); InitSignalHandle(); } void ExitSignalHandle::Stop() { HandleSignal(0); } bool ExitSignalHandle::HasStopped() const { return !is_running_; } void ExitSignalHandle::HandleSignal(int sig) { auto &instance = Instance(); instance.HandleSignalInner(sig); } void ExitSignalHandle::HandleSignalInner(int sig) { if (!has_exited_.test_and_set()) { exit_signal_ = sig; master_exit_requested_.set_value(); worker_exit_requested_.set_value(); agent_exit_requested_.set_value(); is_running_ = false; } } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/common/exit_handle.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_EXIT_HANDLE_H #define MINDSPORE_SERVING_EXIT_HANDLE_H #include #include #include #include "common/serving_common.h" namespace mindspore { namespace serving { // Handle Ctrl+C signal. When the master or worker is waiting for the Ctrl+C signal, // it can continue to perform subsequent operations, such as cleaning. 
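// A usage sketch based on the interface below (illustrative only):
//   ExitSignalHandle::Instance().Start();       // installs the SIGINT/SIGTERM handlers
//   ExitSignalHandle::Instance().MasterWait();  // blocks until Ctrl+C or Stop()
//   /* perform cleanup, then exit */
// Stop() may be called from another thread to release all waiters without a signal.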
class MS_API ExitSignalHandle { public: static ExitSignalHandle &Instance(); void InitSignalHandle(); void MasterWait(); void WorkerWait(); void AgentWait(); void Start(); void Stop(); bool HasStopped() const; private: std::promise master_exit_requested_; std::promise worker_exit_requested_; std::promise agent_exit_requested_; std::atomic_flag has_exited_ = true; std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT; std::atomic_bool is_running_ = false; int exit_signal_ = 0; static void HandleSignal(int sig); void HandleSignalInner(int sig); }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_EXIT_HANDLE_H ================================================ FILE: mindspore_serving/ccsrc/common/float16.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_COMMON_FLOAT16_H_ #define MINDSPORE_SERVING_COMMON_FLOAT16_H_ #if defined(ENABLE_ARM32) || defined(ENABLE_ARM64) // Built for lite and ARM #include using float16 = float16_t; inline float half_to_float(float16 h) { return static_cast(h); } #else #include #include "Eigen/Core" using float16 = Eigen::half; using HalfToFloat = std::function; const inline HalfToFloat half_to_float = Eigen::half_impl::half_to_float; #endif #endif // MINDSPORE_SERVING_COMMON_FLOAT16_H_ ================================================ FILE: mindspore_serving/ccsrc/common/grpc_async_server.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_GRPC_ASYNC_SERVER_H #define MINDSPORE_SERVING_GRPC_ASYNC_SERVER_H #include #include #include #include #include #include #include #include "common/serving_common.h" #include "common/ssl_config.h" #include "common/utils.h" namespace mindspore::serving { class GrpcAsyncServiceContextBase { public: GrpcAsyncServiceContextBase() = default; virtual ~GrpcAsyncServiceContextBase() = default; virtual void NewAndHandleRequest() = 0; bool HasFinish() const { return finished_; } void SetFinish() { finished_ = true; } private: bool finished_ = false; }; template class GrpcAsyncServiceContext : public GrpcAsyncServiceContextBase { public: GrpcAsyncServiceContext(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq) : service_impl_(service_impl), async_service_(async_service), cq_(cq) {} ~GrpcAsyncServiceContext() = default; GrpcAsyncServiceContext() = delete; virtual void StartEnqueueRequest() = 0; virtual void HandleRequest() = 0; static void EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq) { auto call = new Derived(service_impl, async_service, cq); call->StartEnqueueRequest(); } void NewAndHandleRequest() final { EnqueueRequest(service_impl_, async_service_, cq_); HandleRequest(); } protected: grpc::ServerContext ctx_; ServiceImpl *service_impl_; AsyncService *async_service_; grpc::ServerCompletionQueue *cq_; }; template class GrpcAsyncServer { public: GrpcAsyncServer() {} virtual ~GrpcAsyncServer() { Stop(); } virtual void EnqueueRequests() = 0; Status Start(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_mb_size, const std::string &server_tag) { if (in_running_) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: " << server_tag << " server is already running"; } grpc::ServerBuilder builder; if (max_msg_mb_size > 0) { constexpr uint32_t mbytes_to_bytes = 1u << 20; builder.SetMaxSendMessageSize(static_cast(max_msg_mb_size * mbytes_to_bytes)); builder.SetMaxReceiveMessageSize(static_cast(max_msg_mb_size * mbytes_to_bytes)); } builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); int port_tcpip = 0; auto creds = BuildServerCredentialsFromSSLConfigFile(ssl_config); Status status; status = CheckServerAddress(socket_address, server_tag); if (status != SUCCESS) { return status; } builder.AddListeningPort(socket_address, creds, &port_tcpip); status = RegisterService(&builder); if (status != SUCCESS) return status; cq_ = builder.AddCompletionQueue(); server_ = builder.BuildAndStart(); if (!server_) { return INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: " << server_tag << " server start failed, create server failed, address " << socket_address; } auto grpc_server_run = [this]() { HandleRequests(); }; grpc_thread_ = std::thread(grpc_server_run); in_running_ = true; MSI_LOG(INFO) << server_tag << " server start success, listening on " << socket_address; std::cout << "Serving: " << server_tag << " server start success, listening on " << socket_address << std::endl; return SUCCESS; } Status CheckServerAddress(const std::string &address, const std::string &server_tag) { Status status; std::string prefix = "unix:"; if (address.substr(0, prefix.size()) == prefix) { if (address.size() > prefix.size()) { return SUCCESS; } else { status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: Empty grpc server unix domain socket address"; return status; } } status = common::CheckAddress(address, server_tag, nullptr, nullptr); if (status != SUCCESS) { return 
status; } return SUCCESS; } std::shared_ptr BuildServerCredentialsFromSSLConfigFile(const SSLConfig &ssl_config) { if (!ssl_config.use_ssl) { return grpc::InsecureServerCredentials(); } grpc::SslServerCredentialsOptions ssl_ops(ssl_config.verify_client ? GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY : GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE); if (!ssl_config.custom_ca.empty()) { ssl_ops.pem_root_certs = ssl_config.custom_ca; } grpc::SslServerCredentialsOptions::PemKeyCertPair keycert = {ssl_config.private_key, ssl_config.certificate}; ssl_ops.pem_key_cert_pairs.push_back(keycert); return grpc::SslServerCredentials(ssl_ops); } Status HandleRequests() { void *tag; bool ok = false; EnqueueRequests(); while (cq_->Next(&tag, &ok)) { ProcessRequest(tag, ok); } return SUCCESS; } void Stop() { if (in_running_) { if (server_) { server_->Shutdown(); } // Always shutdown the completion queue after the server. if (cq_) { cq_->Shutdown(); } grpc_thread_.join(); } in_running_ = false; } Status RegisterService(grpc::ServerBuilder *builder) { builder->RegisterService(&svc_); return SUCCESS; } void ProcessRequest(void *tag, bool rpc_ok) { auto rq = static_cast(tag); if (rq->HasFinish() || !rpc_ok) { // !rpc_ok: cancel get request when shutting down. delete rq; } else { rq->NewAndHandleRequest(); rq->SetFinish(); // will delete next time } } protected: std::unique_ptr cq_; std::unique_ptr server_; AsyncService svc_; bool in_running_ = false; std::thread grpc_thread_; }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_GRPC_ASYNC_SERVER_H ================================================ FILE: mindspore_serving/ccsrc/common/grpc_client.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "common/grpc_client.h" namespace mindspore { namespace serving { std::unique_ptr client_; std::unique_ptr distributed_client_; } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/common/grpc_client.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H #define MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H #include #include #include #include #include #include #include #include #include "common/serving_common.h" #include "proto/ms_service.pb.h" #include "proto/ms_service.grpc.pb.h" #include "proto/ms_master.pb.h" #include "proto/ms_master.grpc.pb.h" #include "proto/ms_worker.grpc.pb.h" #include "proto/ms_agent.pb.h" #include "proto/ms_agent.grpc.pb.h" namespace mindspore { namespace serving { using PredictOnFinish = std::function; using AsyncPredictCallback = std::function; template class MSServiceClient { public: MSServiceClient() = default; ~MSServiceClient() { if (in_running_) { cq_.Shutdown(); if (client_thread_.joinable()) { try { client_thread_.join(); } catch (const std::system_error &) { } catch (...) { } } } in_running_ = false; } void Start() { client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this); in_running_ = true; } void AsyncCompleteRpc() { void *got_tag; bool ok = false; while (cq_.Next(&got_tag, &ok)) { AsyncClientCall *call = static_cast(got_tag); if (call->status.ok()) { call->callback(SUCCESS); } else { MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message() << ", target address: " << call->target_address; call->callback(Status(WORKER_UNAVAILABLE, call->status.error_message())); } delete call; } } void PredictAsync(const Request &request, Reply *reply, MSStub *stub, const AsyncPredictCallback &callback, const std::string &target_address) { AsyncClientCall *call = new AsyncClientCall; call->reply = reply; call->callback = callback; call->target_address = target_address; call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_); call->response_reader->StartCall(); call->response_reader->Finish(call->reply, &call->status, call); } private: struct AsyncClientCall { grpc::ClientContext context; grpc::Status status; Reply *reply; std::string target_address; AsyncPredictCallback callback; std::shared_ptr> response_reader; }; grpc::CompletionQueue cq_; std::thread client_thread_; bool in_running_ = false; }; using MSPredictClient = MSServiceClient; using MSDistributedClient = MSServiceClient; extern std::unique_ptr client_; extern std::unique_ptr distributed_client_; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H ================================================ FILE: mindspore_serving/ccsrc/common/grpc_server.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "common/grpc_server.h" namespace mindspore::serving { Status GrpcServer::Start(const std::shared_ptr &service, const std::string &server_address, int max_msg_mb_size, const std::string &server_tag) { service_ = service; if (in_running_) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: " << server_tag << " server is already running"; } // Set the port is not reuseable auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0); grpc::ServerBuilder serverBuilder; (void)serverBuilder.SetOption(std::move(option)); if (max_msg_mb_size > 0) { constexpr int mbytes_to_bytes = static_cast(1u << 20); (void)serverBuilder.SetMaxSendMessageSize(max_msg_mb_size * mbytes_to_bytes); (void)serverBuilder.SetMaxReceiveMessageSize(max_msg_mb_size * mbytes_to_bytes); } (void)serverBuilder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); (void)serverBuilder.RegisterService(service.get()); server_ = serverBuilder.BuildAndStart(); if (server_ == nullptr) { return INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: " << server_tag << " server start failed, create server failed, address " << server_address; } auto grpc_server_run = [this, server_address, server_tag]() { MSI_LOG(INFO) << server_tag << " server start success, listening on " << server_address; server_->Wait(); }; grpc_thread_ = std::thread(grpc_server_run); in_running_ = true; return SUCCESS; } void GrpcServer::Stop() { if (in_running_) { server_->Shutdown(); grpc_thread_.join(); server_ = nullptr; } in_running_ = false; } std::shared_ptr GrpcServer::CreateChannel(const std::string &target_str) { grpc::ChannelArguments channel_args; constexpr int mbytes_to_bytes = static_cast(1u << 20); channel_args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, gRpcMaxMBMsgSize * mbytes_to_bytes); std::shared_ptr channel = grpc::CreateCustomChannel(target_str, grpc::InsecureChannelCredentials(), channel_args); return channel; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/common/grpc_server.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */
#ifndef MINDSPORE_SERVING_GRPC_SERVER_H
#define MINDSPORE_SERVING_GRPC_SERVER_H

#include <grpcpp/grpcpp.h>
#include <memory>
#include <string>
#include <thread>
#include "common/serving_common.h"

namespace mindspore::serving {

constexpr int gRpcDefaultMsgMBSize = 100;
constexpr int gRpcMaxMBMsgSize = 512;  // max 512 MB

class GrpcServer {
 public:
  GrpcServer() = default;
  ~GrpcServer() noexcept { Stop(); }
  Status Start(const std::shared_ptr<grpc::Service> &service, const std::string &server_address, int max_msg_size,
               const std::string &server_tag);
  void Stop();
  static std::shared_ptr<grpc::Channel> CreateChannel(const std::string &target_str);

 private:
  std::unique_ptr<grpc::Server> server_;
  std::thread grpc_thread_;
  bool in_running_ = false;
  std::shared_ptr<grpc::Service> service_;
};

}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_GRPC_SERVER_H


================================================
FILE: mindspore_serving/ccsrc/common/heart_beat.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/heart_beat.h"

namespace mindspore::serving {}  // namespace mindspore::serving


================================================
FILE: mindspore_serving/ccsrc/common/heart_beat.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_HEART_BEAT_H
#define MINDSPORE_SERVING_HEART_BEAT_H

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <system_error>
#include <thread>
#include <unordered_map>
#include <utility>
#include "common/serving_common.h"
#include "common/grpc_server.h"
#include "proto/ms_service.pb.h"
#include "proto/ms_service.grpc.pb.h"

namespace mindspore::serving {

using TimerCallback = std::function<void()>;

// Runs a callback periodically on a background thread until stopped.
class Timer {
 public:
  Timer() {}
  ~Timer() {
    is_stoped_.store(true);
    cv_.notify_all();
    if (thread_.joinable()) {
      try {
        thread_.join();
      } catch (const std::system_error &) {
      } catch (...) {
      }
    }
  }
  void StartTimer(int64_t millisecond, TimerCallback callback) {
    auto timer_run = [this, millisecond, callback]() {
      while (!is_stoped_.load()) {
        std::unique_lock<std::mutex> lk(cv_m_);
        if (cv_.wait_for(lk, std::chrono::milliseconds(millisecond)) == std::cv_status::timeout) {
          callback();
        }
      }
    };
    thread_ = std::thread(timer_run);
  }
  void StopTimer() {
    is_stoped_.store(true);
    cv_.notify_all();
  }

 private:
  std::mutex cv_m_;
  std::thread thread_;
  std::condition_variable cv_;
  std::atomic<bool> is_stoped_ = false;
};

// Watches peer processes via async Ping/Pong gRPC heart beats. SendStub and
// RecvStub are the generated gRPC service types for the two directions.
template <class SendStub, class RecvStub>
class Watcher {
 public:
  explicit Watcher(const std::string &host_address) { host_address_ = host_address; }
  ~Watcher() {
    if (ping_running_) {
      ping_cq_.Shutdown();
      if (ping_thread_.joinable()) {
        try {
          ping_thread_.join();
        } catch (const std::system_error &) {
        } catch (...) {
        }
      }
    }
    ping_running_ = false;
    if (pong_running_) {
      pong_cq_.Shutdown();
      if (pong_thread_.joinable()) {
        try {
          pong_thread_.join();
        } catch (const std::system_error &) {
        } catch (...) {
        }
      }
    }
    pong_running_ = false;
  }

  void StartWatch(const std::string &address) {
    if (ping_running_ == false) {
      ping_thread_ = std::thread(&Watcher::AsyncPingRpc, this);
      ping_running_ = true;
    }
    auto it = watchee_map_.find(address);
    if (it != watchee_map_.end()) {
      MSI_LOG(INFO) << "watchee exist: " << address;
      it->second.timeouts_ = 0;
      it->second.timer_ = std::make_shared<Timer>();
      // add timer
      it->second.timer_->StartTimer(max_time_out_ / max_ping_times_,
                                    std::bind(&Watcher::RecvPongTimeOut, this, address));
    } else {
      WatcheeContext context;
      auto channel = GrpcServer::CreateChannel(address);
      context.stub_ = SendStub::NewStub(channel);
      context.timer_ = std::make_shared<Timer>();
      // add timer
      context.timer_->StartTimer(max_time_out_ / max_ping_times_,
                                 std::bind(&Watcher::RecvPongTimeOut, this, address));
      watchee_map_.insert(make_pair(address, context));
    }
    MSI_LOG(INFO) << "Begin to send ping to " << address;
    SendPing(address);
  }

  void StopWatch(const std::string &address) {
    // clear map and timer
    auto it = watchee_map_.find(address);
    if (it == watchee_map_.end()) {
      MSI_LOG(INFO) << "watchee not exist: " << address;
      return;
    }
    it->second.timer_->StopTimer();
    watchee_map_.erase(address);
  }

  void SendPing(const std::string &address) {
    auto it = watchee_map_.find(address);
    if (it == watchee_map_.end()) {
      MSI_LOG(INFO) << "watchee not exist: " << address;
      return;
    }
    it->second.timeouts_ += 1;
    // send async message
    PingAsync(address);
  }

  void RecvPing(const std::string &address) {
    std::unique_lock<std::mutex> lock{m_lock_};
    if (pong_running_ == false) {
      pong_thread_ = std::thread(&Watcher::AsyncPongRpc, this);
      pong_running_ = true;
    }
    // recv message
    auto it = watcher_map_.find(address);
    if (it != watcher_map_.end()) {
      it->second.timer_->StopTimer();
      it->second.timer_ = std::make_shared<Timer>();
      // add timer
      it->second.timer_->StartTimer(max_time_out_, std::bind(&Watcher::RecvPingTimeOut, this, address));
    } else {
      WatcherContext context;
      auto channel = GrpcServer::CreateChannel(address);
      context.stub_ = RecvStub::NewStub(channel);
      context.timer_ = std::make_shared<Timer>();
      // add timer
      context.timer_->StartTimer(max_time_out_, std::bind(&Watcher::RecvPingTimeOut, this, address));
      watcher_map_.insert(make_pair(address, context));
      MSI_LOG(INFO) << "Begin to send pong to " << address;
    }
    // send async message
    PongAsync(address);
  }

  void RecvPong(const std::string &address) {
    std::unique_lock<std::mutex> lock{m_lock_};
    // recv message
    auto it = watchee_map_.find(address);
    if (it != watchee_map_.end()) {
      it->second.timeouts_ = 0;
    } else {
      MSI_LOG(INFO) << "Recv Pong after timeout or stop";
    }
  }

  void RecvPongTimeOut(const std::string &address) {
    std::unique_lock<std::mutex> lock{m_lock_};
    auto it = watchee_map_.find(address);
    if (it != watchee_map_.end()) {
      if (it->second.timeouts_ >= max_ping_times_) {
        // add exit handle
        MSI_LOG(ERROR) << "Recv Pong Time Out from " << address << ", host address is " << host_address_;
        it->second.timer_->StopTimer();
        // need erase map
        return;
      }
      SendPing(address);
    } else {
      MSI_LOG(INFO) << "Recv Pong Time Out after timeout or stop";
    }
  }

  void RecvPingTimeOut(const std::string &address) {
    std::unique_lock<std::mutex> lock{m_lock_};
    auto it = watcher_map_.find(address);
    if (it != watcher_map_.end()) {
      MSI_LOG(ERROR) << "Recv Ping Time Out from " << address << ", host address is " << host_address_;
      // add exit handle
      it->second.timer_->StopTimer();
      // need erase map
    } else {
      MSI_LOG(INFO) << "Recv Ping Time Out after timeout or stop";
    }
  }

  void PingAsync(const std::string &address) {
    auto it = watchee_map_.find(address);
    if (it != watchee_map_.end()) {
      proto::PingRequest request;
      request.set_address(host_address_);
      AsyncPingCall *call = new AsyncPingCall;
      call->response_reader = it->second.stub_->PrepareAsyncPing(&call->context, request, &ping_cq_);
      call->response_reader->StartCall();
      call->response_reader->Finish(&call->reply, &call->status, call);
    }
  }

  void PongAsync(const std::string &address) {
    auto it = watcher_map_.find(address);
    if (it != watcher_map_.end()) {
      proto::PongRequest request;
      request.set_address(host_address_);
      AsyncPongCall *call = new AsyncPongCall;
      call->response_reader = it->second.stub_->PrepareAsyncPong(&call->context, request, &pong_cq_);
      call->response_reader->StartCall();
      call->response_reader->Finish(&call->reply, &call->status, call);
    }
  }

  void AsyncPingRpc() {
    void *got_tag;
    bool ok = false;
    while (ping_cq_.Next(&got_tag, &ok)) {
      AsyncPingCall *call = static_cast<AsyncPingCall *>(got_tag);
      if (!call->status.ok()) {
        MSI_LOG_DEBUG << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message();
      }
      delete call;
    }
  }

  void AsyncPongRpc() {
    void *got_tag;
    bool ok = false;
    while (pong_cq_.Next(&got_tag, &ok)) {
      AsyncPongCall *call = static_cast<AsyncPongCall *>(got_tag);
      if (!call->status.ok()) {
        MSI_LOG_DEBUG << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message();
      }
      delete call;
    }
  }

 private:
  struct WatcheeContext {
    uint64_t timeouts_ = 0;
    std::shared_ptr<Timer> timer_ = nullptr;
    std::shared_ptr<typename SendStub::Stub> stub_ = nullptr;
  };
  struct WatcherContext {
    uint64_t timeouts_ = 0;
    std::shared_ptr<Timer> timer_ = nullptr;
    std::shared_ptr<typename RecvStub::Stub> stub_ = nullptr;
  };
  struct AsyncPingCall {
    grpc::ClientContext context;
    grpc::Status status;
    proto::PingReply reply;
    std::shared_ptr<grpc::ClientAsyncResponseReader<proto::PingReply>> response_reader;
  };
  struct AsyncPongCall {
    grpc::ClientContext context;
    grpc::Status status;
    proto::PongReply reply;
    std::shared_ptr<grpc::ClientAsyncResponseReader<proto::PongReply>> response_reader;
  };

  std::string host_address_;
  uint64_t max_ping_times_ = 20;
  uint64_t max_time_out_ = 20000;  // 20s
  std::unordered_map<std::string, WatcheeContext> watchee_map_;
  std::unordered_map<std::string, WatcherContext> watcher_map_;
  std::mutex m_lock_;
  grpc::CompletionQueue ping_cq_;
  std::thread ping_thread_;
  bool ping_running_ = false;
  grpc::CompletionQueue pong_cq_;
  std::thread pong_thread_;
  bool pong_running_ = false;
};

}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_HEART_BEAT_H
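// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the repository sources): driving the
// Timer above to run a periodic callback. The 500 ms interval and the demo
// function name are assumptions chosen for the example.
// ---------------------------------------------------------------------------
#include <chrono>
#include <thread>
#include "common/heart_beat.h"

void TimerDemo() {
  mindspore::serving::Timer timer;
  // Invoke the callback roughly every 500 ms until StopTimer() is called.
  timer.StartTimer(500, []() { MSI_LOG(INFO) << "heart beat tick"; });
  std::this_thread::sleep_for(std::chrono::seconds(2));  // ~4 ticks
  timer.StopTimer();
}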
================================================
FILE: mindspore_serving/ccsrc/common/instance.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_INSTANCE_H
#define MINDSPORE_SERVING_INSTANCE_H

#include <map>
#include <memory>
#include "common/serving_common.h"
#include "common/servable.h"
#include "common/instance_data.h"

namespace mindspore::serving {

struct Instance {
  InstanceData data;  // for inputs of function, predict, output
  const MethodSignature *method_def = nullptr;
  uint64_t stage_index = 0;
  uint64_t stage_max = 0;
  std::map<uint64_t, InstanceData> stage_data_list;  // input: 0, stage: 1-n
  uint64_t user_id = 0;
  Status error_msg = SUCCESS;
};

using InstancePtr = std::shared_ptr<Instance>;

}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_INSTANCE_H


================================================
FILE: mindspore_serving/ccsrc/common/instance_data.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_INSTANCE_DATA_H
#define MINDSPORE_SERVING_INSTANCE_DATA_H

#include <vector>
#include "common/serving_common.h"

namespace mindspore::serving {

using InstanceData = std::vector<TensorBasePtr>;

struct ResultInstance {
  InstanceData data;
  Status error_msg = SUCCESS;
};

}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_INSTANCE_DATA_H
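// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the repository sources): a stage that fails
// reports the failure through error_msg instead of filling data. The helper
// name and the error text are assumptions for the example.
// ---------------------------------------------------------------------------
#include "common/instance_data.h"

namespace mindspore::serving {
ResultInstance MakeErrorResult() {
  ResultInstance result;
  // Status(code, message) as used elsewhere in this codebase.
  result.error_msg = Status(INVALID_INPUTS, "input tensor shape mismatch");
  return result;
}
}  // namespace mindspore::serving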
*/ #include "common/log.h" #include #include #include #include #define google mindspore_serving_private #include "glog/logging.h" #include "common/utils.h" namespace mindspore { namespace serving { int g_ms_serving_log_level = static_cast(LOG_WARNING); static std::string GetTimeString() { #if defined(_WIN32) || defined(_WIN64) time_t time_seconds = time(0); struct tm now_time; localtime_s(&now_time, &time_seconds); constexpr int base_year = 1900; std::stringstream ss; ss << now_time.tm_year + base_year << "-" << now_time.tm_mon + 1 << "-" << now_time.tm_mday << " " << now_time.tm_hour << ":" << now_time.tm_min << ":" << now_time.tm_sec; return ss.str(); #else constexpr auto BUFLEN = 80; char buf[BUFLEN] = {'\0'}; struct timeval cur_time; (void)gettimeofday(&cur_time, nullptr); struct tm now; constexpr int width = 3; constexpr int64_t time_convert_unit = 1000; (void)localtime_r(&cur_time.tv_sec, &now); (void)strftime(buf, BUFLEN, "%Y-%m-%d-%H:%M:%S", &now); // format date and time std::stringstream ss; ss << "." << std::setfill('0') << std::setw(width) << cur_time.tv_usec / time_convert_unit << "." << std::setfill('0') << std::setw(width) << cur_time.tv_usec % time_convert_unit; return std::string(buf) + ss.str(); #endif } static std::string GetProcName() { #if defined(__APPLE__) || defined(__FreeBSD__) const std::string appname = getprogname(); #elif defined(_GNU_SOURCE) const std::string appname = program_invocation_name; #else const std::string appname = "?"; #endif // some times, the appname is an absolute path, its too long std::string app_name(appname); std::size_t pos = app_name.rfind("/"); if (pos == std::string::npos) { return app_name; } if (pos + 1 >= app_name.size()) { return app_name; } return app_name.substr(pos + 1); } static std::string GetLogLevel(MsLogLevel level) { switch (level) { case LOG_DEBUG: return "DEBUG"; case LOG_INFO: return "INFO"; case LOG_WARNING: return "WARNING"; case LOG_EXCEPTION: return "EXCEPTION"; case LOG_ERROR: default: return "ERROR"; } } // convert MsLogLevel to corresponding glog level static int GetGlogLevel(MsLogLevel level) { switch (level) { case LOG_DEBUG: case LOG_INFO: return google::GLOG_INFO; case LOG_WARNING: return google::GLOG_WARNING; case LOG_ERROR: case LOG_EXCEPTION: default: return google::GLOG_ERROR; } } // get threshold level static int GetThresholdLevel(const std::string &threshold) { if (threshold.empty()) { return google::GLOG_WARNING; } else if (threshold == "DEBUG" || threshold == "INFO") { return google::GLOG_INFO; } else if (threshold == "WARNING") { return google::GLOG_WARNING; } else if (threshold == "ERROR" || threshold == "CRITICAL") { return google::GLOG_ERROR; } else { return google::GLOG_WARNING; } } void LogWriter::OutputLog(const std::string &msg_str) const { if (static_cast(log_level_) < g_ms_serving_log_level) { return; } auto submodule_name = "SERVING"; google::LogMessage("", 0, GetGlogLevel(log_level_)).stream() << "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << std::hex << std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " " << "[" << file_ << ":" << line_ << "] " << func_ << "] " << msg_str << std::endl; } static int GetGlobalLogLevel() { return FLAGS_v; } enum class LogConfigToken : size_t { INVALID, // indicate invalid token LEFT_BRACE, // '{' RIGHT_BRACE, // '}' VARIABLE, // '[A-Za-z][A-Za-z0-9_]*' NUMBER, // [0-9]+ COMMA, // ',' COLON, // ':' EOS, // End Of String, '\0' NUM_LOG_CFG_TOKENS }; static const char 
*g_tok_names[static_cast(LogConfigToken::NUM_LOG_CFG_TOKENS)] = { "invalid", // indicate invalid token "{", // '{' "}", // '}' "variable", // '[A-Za-z][A-Za-z0-9_]*' "number", // [0-9]+ ",", // ',' ":", // ':' "end-of-string", // End Of String, '\0' }; static inline bool IsAlpha(char ch) { return (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z'); } static inline bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; } class LogConfigLexer { public: explicit LogConfigLexer(const std::string &text) : buffer_(text) { cur_idx_ = 0; } ~LogConfigLexer() = default; // skip white space, and return the first char after white space char SkipWhiteSpace() { while (cur_idx_ < buffer_.size()) { char ch = buffer_[cur_idx_]; if (ch == ' ' || ch == '\t') { ++cur_idx_; continue; } return ch; } return '\0'; } LogConfigToken GetNext(std::string *const ptr) { char ch = SkipWhiteSpace(); // clang-format off static const std::map single_char_map = { {'{', LogConfigToken::LEFT_BRACE}, {'}', LogConfigToken::RIGHT_BRACE}, {',', LogConfigToken::COMMA}, {':', LogConfigToken::COLON}, {'\0', LogConfigToken::EOS}, }; // clang-format on auto iter = single_char_map.find(ch); if (iter != single_char_map.end()) { if (ptr != nullptr) { *ptr = std::string() + ch; } ++cur_idx_; return iter->second; } else if (IsAlpha(ch)) { std::ostringstream oss; do { oss << ch; ch = buffer_[++cur_idx_]; } while (cur_idx_ < buffer_.size() && (IsAlpha(ch) || IsDigit(ch) || ch == '_')); if (ptr != nullptr) { *ptr = std::string(oss.str()); } return LogConfigToken::VARIABLE; } else if (IsDigit(ch)) { std::ostringstream oss; do { oss << ch; ch = buffer_[++cur_idx_]; } while (cur_idx_ < buffer_.size() && IsDigit(ch)); if (ptr != nullptr) { *ptr = std::string(oss.str()); } return LogConfigToken::NUMBER; } return LogConfigToken::INVALID; } private: std::string buffer_; size_t cur_idx_; }; class LogConfigParser { public: explicit LogConfigParser(const std::string &cfg) : lexer(cfg) {} ~LogConfigParser() = default; bool Expect(LogConfigToken expected, LogConfigToken tok) const { if (expected != tok) { MSI_LOG(WARNING) << "Parse submodule log configuration text error, expect `" << g_tok_names[static_cast(expected)] << "`, but got `" << g_tok_names[static_cast(tok)] << "`. The whole configuration will be ignored."; return false; } return true; } // The text of config MS_SUBMODULE_LOG_v is in the form {submodule1:log_level1,submodule2:log_level2,...}. // Valid values of log levels are: 0 - debug, 1 - info, 2 - warning, 3 - error // e.g. 
MS_SUBMODULE_LOG_v={PARSER:0, ANALYZER:2, PIPELINE:1} std::map Parse() { std::map log_levels; bool flag_error = false; std::string text; auto tok = lexer.GetNext(&text); // empty string if (tok == LogConfigToken::EOS) { return log_levels; } if (!Expect(LogConfigToken::LEFT_BRACE, tok)) { return log_levels; } do { std::string key, val; tok = lexer.GetNext(&key); if (!Expect(LogConfigToken::VARIABLE, tok)) { flag_error = true; break; } tok = lexer.GetNext(&text); if (!Expect(LogConfigToken::COLON, tok)) { flag_error = true; break; } tok = lexer.GetNext(&val); if (!Expect(LogConfigToken::NUMBER, tok)) { flag_error = true; break; } log_levels[key] = val; tok = lexer.GetNext(&text); } while (tok == LogConfigToken::COMMA); if (!flag_error && !Expect(LogConfigToken::RIGHT_BRACE, tok)) { flag_error = true; } if (flag_error) { log_levels.clear(); } return log_levels; } private: LogConfigLexer lexer; }; bool ParseLogLevel(const std::string &str_level, MsLogLevel *ptr_level) { if (str_level.size() == 1) { int ch = str_level.c_str()[0]; ch = ch - '0'; // subtract ASCII code of '0', which is 48 if (ch >= static_cast(LOG_DEBUG) && ch <= static_cast(LOG_ERROR)) { if (ptr_level != nullptr) { *ptr_level = static_cast(ch); } return true; } } return false; } void InitSubModulesLogLevel() { // initialize submodule's log level using global auto global_log_level = GetGlobalLogLevel(); g_ms_serving_log_level = global_log_level; // set submodule's log level auto submodule = common::GetEnv("MS_SUBMODULE_LOG_v"); MSI_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; LogConfigParser parser(submodule); auto configs = parser.Parse(); for (const auto &cfg : configs) { if (cfg.first == "SERVING") { MsLogLevel submodule_log_level; if (!ParseLogLevel(cfg.second, &submodule_log_level)) { MSI_LOG(WARNING) << "Illegal log level value " << cfg.second << " for " << cfg.first << ", ignore it."; continue; } g_ms_serving_log_level = static_cast(submodule_log_level); } } } void common_log_init(void) { // do not use glog predefined log prefix FLAGS_log_prefix = false; // disable log buffer, real-time output FLAGS_logbufsecs = 0; // set default log level to WARNING if (common::GetEnv("GLOG_v").empty()) { FLAGS_v = static_cast(mindspore::serving::LOG_WARNING); } // set default log file mode to 0640 if (common::GetEnv("GLOG_logfile_mode").empty()) { FLAGS_logfile_mode = 0640; } std::string logtostderr = common::GetEnv("GLOG_logtostderr"); // default print log to screen if (logtostderr.empty()) { FLAGS_logtostderr = true; } else if (logtostderr == "0" && common::GetEnv("GLOG_log_dir").empty()) { FLAGS_logtostderr = true; MSI_LOG(WARNING) << "`GLOG_log_dir` is not set, output log to screen."; } // default GLOG_stderrthreshold level to WARNING auto threshold = common::GetEnv("GLOG_stderrthreshold"); FLAGS_stderrthreshold = GetThresholdLevel(threshold); mindspore::serving::InitSubModulesLogLevel(); } } // namespace serving } // namespace mindspore extern "C" { #if defined(_WIN32) || defined(_WIN64) __attribute__((constructor)) void mindspore_serving_log_init(void) { #else void mindspore_serving_log_init(void) { #endif static bool is_glog_inited = false; if (!is_glog_inited) { #if !defined(_WIN32) && !defined(_WIN64) google::InitGoogleLogging("mindspore_serving"); #endif is_glog_inited = true; } mindspore::serving::common_log_init(); } } #undef google ================================================ FILE: mindspore_serving/ccsrc/common/log.h ================================================ /** * Copyright 2020 Huawei Technologies 
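// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the repository sources): configuring the
// serving log levels through the environment variables parsed above, before
// any logging happens. The main() wrapper is an assumption for the example.
// ---------------------------------------------------------------------------
#include <cstdlib>

extern "C" void mindspore_serving_log_init(void);

int main() {
  setenv("GLOG_v", "2", 1);                        // global level: WARNING
  setenv("MS_SUBMODULE_LOG_v", "{SERVING:0}", 1);  // SERVING submodule: DEBUG
  mindspore_serving_log_init();
  return 0;
}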
================================================
FILE: mindspore_serving/ccsrc/common/log.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_LOG_H
#define MINDSPORE_SERVING_LOG_H

#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace mindspore::serving {

#define MS_API __attribute__((visibility("default")))
#define SERVING_LOG_HDR_FILE_REL_PATH "mindspore_serving/ccsrc/common/log.h"

// Get start index of file relative path in __FILE__
static constexpr size_t GetRelPathPos() noexcept {
  return sizeof(__FILE__) > sizeof(SERVING_LOG_HDR_FILE_REL_PATH)
           ? sizeof(__FILE__) - sizeof(SERVING_LOG_HDR_FILE_REL_PATH)
           : 0;
}

#define SERVING_FILE_NAME                                                \
  (sizeof(__FILE__) > mindspore::serving::GetRelPathPos()                \
     ? static_cast<const char *>(__FILE__) + mindspore::serving::GetRelPathPos() \
     : static_cast<const char *>(__FILE__))

class AsStringHelper {
 public:
  template <typename T>
  static std::string AsString(const T &val) noexcept {
    std::stringstream ss;
    ss << val;
    return ss.str();
  }
  static std::string AsString(const bool &val) noexcept { return val ? "true" : "false"; }
  template <typename T>
  static std::string AsString(const std::vector<T> &val) noexcept {
    std::stringstream ss;
    ss << "[";
    for (size_t i = 0; i < val.size(); i++) {
      ss << AsString(val[i]);
      if (i + 1 < val.size()) {
        ss << ", ";
      }
    }
    ss << "]";
    return ss.str();
  }
  template <typename K, typename V>
  static std::string AsString(const std::unordered_map<K, V> &val) noexcept {
    return AsStringMap(val);
  }
  template <typename K, typename V>
  static std::string AsString(const std::map<K, V> &val) noexcept {
    return AsStringMap(val);
  }
  template <typename K, typename V>
  static std::string AsString(const std::vector<std::pair<K, V>> &val) noexcept {
    return AsStringMap(val);
  }

 private:
  template <typename T>
  static std::string AsStringMap(const T &val) noexcept {
    std::stringstream ss;
    ss << "{";
    size_t index = 0;
    for (auto &item : val) {
      ss << AsString(item.first) << ": " << AsString(item.second);
      if (index + 1 < val.size()) {
        ss << ", ";
      }
      index += 1;
    }
    ss << "}";
    return ss.str();
  }
};

class LogStream {
 public:
  LogStream() { sstream_ = std::make_shared<std::stringstream>(); }
  ~LogStream() = default;

  template <typename T>
  LogStream &operator<<(const T &val) noexcept {
    (*sstream_) << val;
    return *this;
  }
  LogStream &operator<<(const bool &val) noexcept {
    (*sstream_) << (val ? "true" : "false");
    return *this;
  }
  template <typename T>
  LogStream &operator<<(const std::vector<T> &val) noexcept {
    (*sstream_) << "[";
    for (size_t i = 0; i < val.size(); i++) {
      (*this) << val[i];
      if (i + 1 < val.size()) {
        (*sstream_) << ", ";
      }
    }
    (*sstream_) << "]";
    return *this;
  }
  template <typename K, typename V>
  LogStream &operator<<(const std::unordered_map<K, V> &val) noexcept {
    return OutputMap(val);
  }
  template <typename K, typename V>
  LogStream &operator<<(const std::map<K, V> &val) noexcept {
    return OutputMap(val);
  }
  LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept {
    (*sstream_) << func;
    return *this;
  }
  friend class LogWriter;
  friend class Status;

 private:
  std::shared_ptr<std::stringstream> sstream_;

  template <typename T>
  LogStream &OutputMap(const T &val) noexcept {
    (*sstream_) << "{";
    size_t index = 0;
    for (auto &item : val) {
      (*this) << item.first << ": " << item.second;
      if (index + 1 < val.size()) {
        (*sstream_) << ", ";
      }
      index += 1;
    }
    (*sstream_) << "}";
    return *this;
  }
};

enum MsLogLevel {
  LOG_DEBUG,
  LOG_INFO,
  LOG_WARNING,
  LOG_ERROR,
  LOG_EXCEPTION,
};

class MS_API LogWriter {
 public:
  LogWriter(const char *file, int line, const char *func, MsLogLevel log_level)
      : file_(file), line_(line), func_(func), log_level_(log_level) {}
  ~LogWriter() = default;

  std::string operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))) {
    std::ostringstream msg;
    msg << stream.sstream_->rdbuf();
    auto msg_str = GetOutputMsg(msg);
    OutputLog(msg_str);
    return msg_str;
  }

  void operator^(const LogStream &stream) const __attribute__((noreturn, visibility("default"))) {
    std::ostringstream msg;
    msg << stream.sstream_->rdbuf();
    auto msg_str = GetOutputMsg(msg);
    OutputLog(msg_str);
    throw std::runtime_error(msg_str);
  }

  std::string GetOutputMsg(const std::ostringstream &msg) const {
    std::string msg_str = msg.str();
    constexpr int max_log_size = 384;
    constexpr int msg_log_start_size = 192;
    // keep the first and last 192 characters of an over-long message
    if (msg_str.length() > max_log_size) {
      msg_str = msg_str.substr(0, msg_log_start_size) + "..." + msg_str.substr(msg_str.length() - msg_log_start_size);
    }
    return msg_str;
  }

 private:
  void OutputLog(const std::string &msg_str) const;
  const char *file_;
  int line_;
  const char *func_;
  MsLogLevel log_level_;
};

extern int g_ms_serving_log_level MS_API;

#define MSILOG_IF(level, condition)                                                          \
  !(condition) ? std::string()                                                               \
               : mindspore::serving::LogWriter(SERVING_FILE_NAME, __LINE__, __FUNCTION__,    \
                                               mindspore::serving::LOG_##level) <            \
                   mindspore::serving::LogStream()

#define MSILOG_NOIF(level)                                                                              \
  mindspore::serving::LogWriter(SERVING_FILE_NAME, __LINE__, __FUNCTION__, mindspore::serving::LOG_##level) < \
    mindspore::serving::LogStream()

inline bool IS_OUTPUT_ON(enum MsLogLevel level) { return static_cast<int>(level) >= g_ms_serving_log_level; }

#define MSILOG_THROW                                                                                          \
  mindspore::serving::LogWriter(SERVING_FILE_NAME, __LINE__, __FUNCTION__, mindspore::serving::LOG_EXCEPTION) ^ \
    mindspore::serving::LogStream()

#define MSI_LOG(level) MSI_LOG_##level

#define MSI_LOG_DEBUG MSILOG_IF(DEBUG, mindspore::serving::IS_OUTPUT_ON(mindspore::serving::LOG_DEBUG))
#define MSI_LOG_INFO MSILOG_IF(INFO, mindspore::serving::IS_OUTPUT_ON(mindspore::serving::LOG_INFO))
#define MSI_LOG_WARNING MSILOG_IF(WARNING, mindspore::serving::IS_OUTPUT_ON(mindspore::serving::LOG_WARNING))
#define MSI_LOG_ERROR MSILOG_IF(ERROR, mindspore::serving::IS_OUTPUT_ON(mindspore::serving::LOG_ERROR))
#define MSI_LOG_EXCEPTION MSILOG_THROW

#define MSI_EXCEPTION_IF_NULL(ptr)                                   \
  do {                                                               \
    if ((ptr) == nullptr) {                                          \
      MSI_LOG_EXCEPTION << ": The pointer[" << #ptr << "] is null."; \
    }                                                                \
  } while (0)

}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_LOG_H
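// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the repository sources): typical use of
// the logging macros declared above. Containers stream through the LogStream
// overloads; the demo function name is an assumption.
// ---------------------------------------------------------------------------
#include <vector>
#include "common/log.h"

void LogDemo(const int *p) {
  std::vector<int> shape{1, 3, 224, 224};
  MSI_LOG(INFO) << "input shape: " << shape;  // prints [1, 3, 224, 224]
  MSI_EXCEPTION_IF_NULL(p);                   // throws std::runtime_error when p is null
  MSI_LOG(WARNING) << "value: " << *p;
}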
*/ #include "common/proto_tensor.h" #include #include #include #include #include #include "common/buffer_tensor.h" #include "common/servable.h" #include "master/dispacther.h" #include "common/shared_memory.h" using std::string; using std::unordered_map; using std::vector; namespace mindspore::serving { const size_t kMaxShapeElementCount = INT32_MAX; ProtoTensor::ProtoTensor(proto::Tensor *other) : tensor_(other) {} ProtoTensor::~ProtoTensor() {} DataType ProtoTensor::data_type() const { MSI_EXCEPTION_IF_NULL(tensor_); return TransDataType2Inference(tensor_->dtype()); } void ProtoTensor::set_data_type(DataType data_type) { MSI_EXCEPTION_IF_NULL(tensor_); tensor_->set_dtype(TransDataType2Proto(data_type)); } std::vector ProtoTensor::shape() const { MSI_EXCEPTION_IF_NULL(tensor_); std::vector result; auto dims = tensor_->shape().dims(); std::transform(dims.begin(), dims.end(), std::back_inserter(result), [](const int64_t dim) { return dim; }); return result; } void ProtoTensor::set_shape(const std::vector &shape) { MSI_EXCEPTION_IF_NULL(tensor_); auto tensor_shape = tensor_->mutable_shape(); tensor_shape->Clear(); size_t element_count = 1; for (auto dim : shape) { if (dim < 0 || (dim > 0 && element_count > kMaxShapeElementCount / dim)) { MSI_LOG_ERROR << "failed to set shape, invalid dim num " << dim; tensor_shape->Clear(); return; } element_count *= dim; tensor_shape->add_dims(dim); } } bool ProtoTensor::resize_data(size_t data_len) { MSI_EXCEPTION_IF_NULL(tensor_); if (tensor_->has_shm_data()) { if (data_len == tensor_->shm_data().data_size()) { return true; } MSI_LOG_EXCEPTION << "Cannot resize shared memory data size from " << tensor_->shm_data().data_size() << " to " << data_len; } string *buffer = tensor_->mutable_data(); if (buffer == nullptr) { MSI_LOG_ERROR << "invalid buffer data"; return false; } buffer->resize(data_len); return true; } size_t ProtoTensor::data_size() const { MSI_EXCEPTION_IF_NULL(tensor_); if (tensor_->has_shm_data()) { return tensor_->shm_data().data_size(); } return tensor_->data().size(); } uint8_t *ProtoTensor::mutable_data() { MSI_EXCEPTION_IF_NULL(tensor_); if (data_size() == 0) { return nullptr; } if (tensor_->has_shm_data()) { auto status = AttachSharedMemory(); if (status != SUCCESS) { return nullptr; } return shm_attach_.offset_address; } return reinterpret_cast(tensor_->mutable_data()->data()); } const uint8_t *ProtoTensor::data() const { MSI_EXCEPTION_IF_NULL(tensor_); if (data_size() == 0) { return nullptr; } if (tensor_->has_shm_data()) { auto status = AttachSharedMemory(); if (status != SUCCESS) { return nullptr; } return shm_attach_.offset_address; } return reinterpret_cast(tensor_->data().data()); } void ProtoTensor::clear_bytes_data() { MSI_EXCEPTION_IF_NULL(tensor_); return tensor_->mutable_bytes_val()->Clear(); } void ProtoTensor::add_bytes_data(const uint8_t *data, size_t bytes_len) { MSI_EXCEPTION_IF_NULL(tensor_); tensor_->add_bytes_val(data, bytes_len); } size_t ProtoTensor::bytes_data_size() const { MSI_EXCEPTION_IF_NULL(tensor_); return tensor_->bytes_val().size(); } void ProtoTensor::get_bytes_data(size_t index, const uint8_t **data, size_t *bytes_len) const { MSI_EXCEPTION_IF_NULL(data); MSI_EXCEPTION_IF_NULL(bytes_len); MSI_EXCEPTION_IF_NULL(tensor_); if (index >= static_cast(tensor_->bytes_val().size())) { MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_->bytes_val().size(); } auto &bytes = tensor_->bytes_val(index); *data = reinterpret_cast(bytes.data()); *bytes_len = bytes.size(); } proto::DataType 
ProtoTensor::TransDataType2Proto(DataType data_type) { const std::unordered_map id2type_map{ {serving::kMSI_Unknown, proto::MS_UNKNOWN}, {serving::kMSI_Bool, proto::MS_BOOL}, {serving::kMSI_Float64, proto::MS_FLOAT64}, {serving::kMSI_Int8, proto::MS_INT8}, {serving::kMSI_Uint8, proto::MS_UINT8}, {serving::kMSI_Int16, proto::MS_INT16}, {serving::kMSI_Uint16, proto::MS_UINT16}, {serving::kMSI_Int32, proto::MS_INT32}, {serving::kMSI_Uint32, proto::MS_UINT32}, {serving::kMSI_Int64, proto::MS_INT64}, {serving::kMSI_Uint64, proto::MS_UINT64}, {serving::kMSI_Float16, proto::MS_FLOAT16}, {serving::kMSI_Float32, proto::MS_FLOAT32}, {serving::kMSI_String, proto::MS_STRING}, {serving::kMSI_Bytes, proto::MS_BYTES}, }; auto it = id2type_map.find(data_type); if (it == id2type_map.end()) { MSI_LOG_WARNING << "failed to set data type, undefined data type " << data_type; return proto::MS_UNKNOWN; } else { return it->second; } } DataType ProtoTensor::TransDataType2Inference(proto::DataType data_type) { const std::unordered_map type2id_map{ {proto::MS_UNKNOWN, kMSI_Unknown}, {proto::MS_BOOL, kMSI_Bool}, {proto::MS_INT8, kMSI_Int8}, {proto::MS_UINT8, kMSI_Uint8}, {proto::MS_INT16, kMSI_Int16}, {proto::MS_UINT16, kMSI_Uint16}, {proto::MS_INT32, kMSI_Int32}, {proto::MS_UINT32, kMSI_Uint32}, {proto::MS_INT64, kMSI_Int64}, {proto::MS_UINT64, kMSI_Uint64}, {proto::MS_FLOAT16, kMSI_Float16}, {proto::MS_FLOAT32, kMSI_Float32}, {proto::MS_FLOAT64, kMSI_Float64}, {proto::MS_STRING, kMSI_String}, {proto::MS_BYTES, kMSI_Bytes}, }; auto it = type2id_map.find(data_type); if (it == type2id_map.end()) { MSI_LOG_WARNING << "failed to get data type, undefined data type " << data_type; return kMSI_Unknown; } else { return it->second; } } void ProtoTensor::SetSharedMemory(const proto::ShmTensorData &shm_data) { *tensor_->mutable_shm_data() = shm_data; } Status ProtoTensor::AttachSharedMemory() const { if (has_attached_shm_) { return SUCCESS; } if (tensor_ == nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "The proto tensor object cannot be nullptr"; } if (!tensor_->has_shm_data()) { return SUCCESS; } const proto::ShmTensorData &shm_data = tensor_->shm_data(); auto status = SharedMemoryManager::Instance().Attach(shm_data.memory_key(), shm_data.bytes_size(), shm_data.data_offset(), shm_data.data_size(), &shm_attach_); if (status != SUCCESS) { MSI_LOG_ERROR << "Attach shared memory failed, memory key: " << shm_data.memory_key() << ", bytes size: " << shm_data.bytes_size() << ", data offset: " << shm_data.data_offset() << ", data size: " << shm_data.data_size(); return status; } has_attached_shm_ = true; return SUCCESS; } void GrpcTensorHelper::GetRequestSpec(const proto::PredictRequest &request, RequestSpec *request_spec) { MSI_EXCEPTION_IF_NULL(request_spec); request_spec->servable_name = request.servable_spec().name(); request_spec->method_name = request.servable_spec().method_name(); request_spec->version_number = request.servable_spec().version_number(); } void GrpcTensorHelper::ConvertProtoWorkerSpec(const proto::RegisterRequest &proto_request, WorkerRegSpec *worker_spec) { MSI_EXCEPTION_IF_NULL(worker_spec); auto &proto_worker_spec = proto_request.worker_spec(); worker_spec->worker_address = proto_worker_spec.address(); worker_spec->worker_pid = proto_worker_spec.worker_pid(); auto &proto_spec = proto_worker_spec.servable_spec(); auto &servable_spec = worker_spec->servable_spec; servable_spec.servable_name = proto_spec.name(); servable_spec.version_number = proto_spec.version_number(); servable_spec.batch_size = 
proto_spec.batch_size(); servable_spec.own_device = proto_spec.own_device(); for (const auto &proto_method : proto_spec.methods()) { ServableMethodInfo method_info; method_info.name = proto_method.name(); method_info.only_model_stage = proto_method.only_model_stage(); for (auto &name : proto_method.input_names()) { method_info.input_names.push_back(name); } servable_spec.methods.push_back(method_info); } ConvertProtoModelInfos(proto_spec.model_infos(), &servable_spec.models); } void GrpcTensorHelper::ConvertWorkerSpec(const WorkerRegSpec &worker_spec, proto::RegisterRequest *proto_request) { auto proto_worker_spec = proto_request->mutable_worker_spec(); proto_worker_spec->set_address(worker_spec.worker_address); proto_worker_spec->set_worker_pid(worker_spec.worker_pid); auto proto_spec = proto_worker_spec->mutable_servable_spec(); const auto &spec = worker_spec.servable_spec; proto_spec->set_name(spec.servable_name); proto_spec->set_version_number(spec.version_number); proto_spec->set_batch_size(spec.batch_size); proto_spec->set_own_device(spec.own_device); for (auto &method : spec.methods) { auto proto_method = proto_spec->add_methods(); proto_method->set_name(method.name); proto_method->set_only_model_stage(method.only_model_stage); for (auto &name : method.input_names) { proto_method->add_input_names(name); } } ConvertModelInfos(spec.models, proto_spec->mutable_model_infos()); } void GrpcTensorHelper::ConvertProtoModelInfos(const proto::ModelInfos &proto_model_infos, std::map *model_infos) { MSI_EXCEPTION_IF_NULL(model_infos); model_infos->clear(); auto convert_tensor_info = [](const proto::TensorInfo &proto_tensor_info) -> TensorInfo { TensorInfo tensor_info; tensor_info.is_no_batch_dim = proto_tensor_info.is_no_batch_dim(); tensor_info.size = proto_tensor_info.size(); tensor_info.data_type = ProtoTensor::TransDataType2Inference(proto_tensor_info.dtype()); auto &proto_shape = proto_tensor_info.shape().dims(); std::copy(proto_shape.begin(), proto_shape.end(), std::back_inserter(tensor_info.shape)); return tensor_info; }; for (const auto &proto_model_it : proto_model_infos.model_infos()) { auto &model_key = proto_model_it.first; auto &proto_model = proto_model_it.second; ModelInfo &model_info = (*model_infos)[model_key]; model_info.batch_size = proto_model.batch_size(); for (auto &proto_subgraph : proto_model.subgraph_infos()) { ModelSubgraphInfo subgraph_info; for (auto &input_tensor : proto_subgraph.inputs()) { subgraph_info.input_infos.push_back(convert_tensor_info(input_tensor)); } for (auto &output_tensor : proto_subgraph.outputs()) { subgraph_info.output_infos.push_back(convert_tensor_info(output_tensor)); } model_info.sub_graph_infos.push_back(subgraph_info); } } } void GrpcTensorHelper::ConvertModelInfos(const std::map &model_infos, proto::ModelInfos *proto_model_infos) { MSI_EXCEPTION_IF_NULL(proto_model_infos); proto_model_infos->Clear(); auto convert_tensor_info = [](const TensorInfo &tensor_info, proto::TensorInfo *proto_tensor_info) { proto_tensor_info->set_is_no_batch_dim(tensor_info.is_no_batch_dim); proto_tensor_info->set_size(tensor_info.size); proto_tensor_info->set_dtype(ProtoTensor::TransDataType2Proto(tensor_info.data_type)); auto proto_shape = proto_tensor_info->mutable_shape()->mutable_dims(); for (auto &dim : tensor_info.shape) { proto_shape->Add(dim); } }; auto &proto_models_items = *(proto_model_infos->mutable_model_infos()); for (const auto &model_it : model_infos) { auto &model_key = model_it.first; auto &model_info = model_it.second; auto &proto_model = 
proto_models_items[model_key]; proto_model.set_batch_size(model_info.batch_size); for (auto &subgraph_info : model_info.sub_graph_infos) { auto proto_subgraph = proto_model.add_subgraph_infos(); for (auto &input_tensor : subgraph_info.input_infos) { convert_tensor_info(input_tensor, proto_subgraph->add_inputs()); } for (auto &output_tensor : subgraph_info.output_infos) { convert_tensor_info(output_tensor, proto_subgraph->add_outputs()); } } } } Status GrpcTensorHelper::CreateInstanceFromRequest(const MethodSignature &method, const proto::PredictRequest &request, vector *results) { MSI_EXCEPTION_IF_NULL(results); results->clear(); Status status; if (request.instances_size() == 0) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Instances count of request cannot be 0, servable: " << method.servable_name << ", method: " << method.method_name; } status = CreateInstanceFromRequestInstances(request, method, results); if (status != SUCCESS) { MSI_LOG_ERROR << "Create instances from request instances failed"; return status; } return SUCCESS; } void GrpcTensorHelper::CreateReplyFromInstances(const proto::PredictRequest &request, const MethodSignature &method, const vector &instances, proto::PredictReply *reply) { auto status = CreateReplyFromInstancesInner(request, method, instances, reply); if (status != SUCCESS) { CreateReplyFromErrorMsg(status, reply); } } Status GrpcTensorHelper::CreateInstanceFromPredictReply(const RequestSpec &request_spec, const proto::PredictReply &reply, std::vector *error, std::vector *results) { MSI_EXCEPTION_IF_NULL(error); MSI_EXCEPTION_IF_NULL(results); results->clear(); error->clear(); if (reply.instances_size() == 0 && reply.error_msg_size() == 0) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "The instance or error count of reply cannot be 0, servable: " << request_spec.servable_name << ", method: " << request_spec.method_name; } std::copy(reply.error_msg().begin(), reply.error_msg().end(), std::back_inserter(*error)); for (auto &item : reply.instances()) { // cppcheck-suppress useStlAlgorithm results->push_back(&item); } return SUCCESS; } Status GrpcTensorHelper::CreatePredictReplyFromInstances(const proto::PredictRequest &request, const std::vector &errors, const std::vector &instances, proto::PredictReply *reply) { MSI_EXCEPTION_IF_NULL(reply); for (auto &instance : instances) { auto proto_instance = reply->add_instances(); if (instance) { *proto_instance->mutable_items() = instance->items(); } } bool all_ok = true; bool all_same = true; for (auto &error : errors) { if (error.error_code() != 0) { all_ok = false; } if (error.error_code() != errors[0].error_code() || error.error_msg() != errors[0].error_msg()) { all_same = false; } } if (!all_ok) { if (all_same) { reply->clear_instances(); auto proto_error = reply->add_error_msg(); proto_error->set_error_msg(errors[0].error_msg()); proto_error->set_error_code(errors[0].error_code()); } else { for (auto &error : errors) { auto proto_error = reply->add_error_msg(); proto_error->set_error_msg(error.error_msg()); proto_error->set_error_code(error.error_code()); } } } return SUCCESS; } Status GrpcTensorHelper::CreatePredictRequestFromInstances(const RequestSpec &request_spec, const std::vector &instances, proto::PredictRequest *request) { MSI_EXCEPTION_IF_NULL(request); auto proto_spec = request->mutable_servable_spec(); proto_spec->set_name(request_spec.servable_name); proto_spec->set_method_name(request_spec.method_name); proto_spec->set_version_number(request_spec.version_number); for (auto &instance : 
instances) { auto proto_instance = request->add_instances(); *proto_instance = *instance; } return SUCCESS; } Status GrpcTensorHelper::CreateReplyFromInstancesInner(const proto::PredictRequest &request, const MethodSignature &method, const std::vector &instances, proto::PredictReply *reply) { MSI_EXCEPTION_IF_NULL(reply); if (instances.empty()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Result instances count invalid, cannot be 0"; } if (instances.size() != static_cast(request.instances_size())) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Result instances number " << instances.size() << " is inconsistent with request instances number " << request.instances_size(); } Status status; size_t err_cnt = 0; for (auto &instance : instances) { if (instance->error_msg != SUCCESS) { err_cnt++; } else if (instance->data.size() != method.outputs.size()) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Result data tensor size " << instance->data.size() << " not equal outputs size " << method.outputs.size() << " defined in method signature"; } } if (err_cnt > 0) { for (auto &instance : instances) { auto proto_err_msg = reply->add_error_msg(); proto_err_msg->set_error_code(instance->error_msg.StatusCode()); if (instance->error_msg == INVALID_INPUTS) { proto_err_msg->set_error_msg(instance->error_msg.StatusMessage()); } else if (instance->error_msg != SUCCESS) { proto_err_msg->set_error_msg(instance->error_msg.StatusMessage()); } } } // create instance reply, same with request for (size_t index = 0; index < instances.size(); index++) { auto proto_instance = reply->add_instances(); auto &instance = instances[index]; if (instance->data.empty()) { continue; } auto &request_output_buffers = request.instances(index).output_buffers(); auto proto_items = proto_instance->mutable_items(); for (size_t i = 0; i < method.outputs.size(); i++) { auto &output_tensor = instance->data[i]; auto &output_name = method.outputs[i]; auto &proto_tensor = (*proto_items)[method.outputs[i]]; ProtoTensor result_tensor(&proto_tensor); auto it = request_output_buffers.find(output_name); if (it != request_output_buffers.end()) { if (output_tensor->is_bytes_val_data()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "The output shared memory cannot be specified in the request" << " when the data type of output " << output_name << " is " << output_tensor->data_type() << ", output name: " << output_name; } auto &shm_data = it->second; if (shm_data.data_size() != output_tensor->data_size()) { return INFER_STATUS_LOG_ERROR(FAILED) << "The data size " << shm_data.data_size() << " of output shared memory " << " is inconsistent with the data size " << output_tensor->data_size() << " of result, output name: " << output_name; } result_tensor.SetSharedMemory(shm_data); status = result_tensor.AttachSharedMemory(); if (status != SUCCESS) { return INFER_STATUS_LOG_ERROR(FAILED) << "Attach output shared memory failed, memory key: " << shm_data.memory_key() << ", bytes size: " << shm_data.bytes_size() << ", data offset: " << shm_data.data_offset() << ", data size: " << shm_data.data_size() << ", output name: " << output_name; } } result_tensor.assign(*output_tensor); } } return SUCCESS; } Status GrpcTensorHelper::CreateInstanceFromRequestInstances(const proto::PredictRequest &request, const MethodSignature &method, std::vector *results) { MSI_EXCEPTION_IF_NULL(results); auto servable_name = request.servable_spec().name(); auto method_name = request.servable_spec().method_name(); Status status; auto &input_names = method.inputs; 
auto &output_names = method.outputs; for (auto &proto_instance : request.instances()) { InstanceData instance_data; for (const auto &input_name : input_names) { auto it = proto_instance.items().find(input_name); if (it == proto_instance.items().end()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Cannot find input " << input_name << " in instance input , servable " << servable_name << ", method " << method_name; } auto &tensor_proto = it->second; status = CheckRequestTensor(tensor_proto); if (status != SUCCESS) { auto status2 = INFER_STATUS(INVALID_INPUTS) << "Instances input " << input_name << " check failed"; MSI_LOG_ERROR << status2.StatusMessage(); return Status(INVALID_INPUTS, status2.StatusMessage() + ", detail: " + status.StatusMessage()); } auto add_tensor = std::make_shared(const_cast(&tensor_proto)); if (tensor_proto.has_shm_data()) { status = add_tensor->AttachSharedMemory(); if (status != SUCCESS) { auto &shm_data = tensor_proto.shm_data(); MSI_LOG_ERROR << "Attach input shared memory failed, memory key: " << shm_data.memory_key() << ", bytes size: " << shm_data.bytes_size() << ", data offset: " << shm_data.data_offset() << ", data size: " << shm_data.data_size() << ", input name: " << input_name; return status; } } instance_data.push_back(add_tensor); } auto &output_buffers = proto_instance.output_buffers(); if (!output_buffers.empty()) { for (auto &buffer : output_buffers) { auto it = std::find(output_names.begin(), output_names.end(), buffer.first); if (it == output_names.end()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "The name " << buffer.first << " of the output buffers cannot be found in the output names " << output_names << " of the method, servable " << servable_name << ", method " << method_name; } auto &shm_data = buffer.second; SharedMemoryAttachItem item; status = SharedMemoryManager::Instance().Attach(shm_data.memory_key(), shm_data.bytes_size(), shm_data.data_offset(), shm_data.data_size(), &item); if (status != SUCCESS) { MSI_LOG_ERROR << "Attach output shared memory failed, memory key: " << shm_data.memory_key() << ", bytes size: " << shm_data.bytes_size() << ", data offset: " << shm_data.data_offset() << ", data size: " << shm_data.data_size() << ", output name: " << buffer.first; return status; } } } results->push_back(instance_data); } return SUCCESS; } Status GrpcTensorHelper::CheckRequestInstances(const proto::PredictRequest &request, const std::vector &input_names) { auto servable_name = request.servable_spec().name(); auto method_name = request.servable_spec().method_name(); Status status; for (auto &proto_instance : request.instances()) { for (const auto &input_name : input_names) { auto it = proto_instance.items().find(input_name); if (it == proto_instance.items().end()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Cannot find input " << input_name << " in instance input , servable " << servable_name << ", method " << method_name; } status = CheckRequestTensor(it->second); if (status != SUCCESS) { auto status2 = INFER_STATUS(INVALID_INPUTS) << "Instances input " << input_name << " check failed"; MSI_LOG_ERROR << status2.StatusMessage(); return Status(INVALID_INPUTS, status2.StatusMessage() + ", detail: " + status.StatusMessage()); } } } return SUCCESS; } void GrpcTensorHelper::CopyFromAgentSpec(const proto::AgentSpec &specs, WorkerAgentSpec *worker_specs) { worker_specs->rank_id = specs.rank_id(); worker_specs->batch_size = specs.batch_size(); for (auto &in : specs.inputs()) { TensorInfo info; info.data_type = 
ProtoTensor::TransDataType2Inference(in.dtype()); info.size = in.size(); info.is_no_batch_dim = in.is_no_batch_dim(); for (auto &dim : in.shape().dims()) { info.shape.push_back(dim); } worker_specs->input_infos.push_back(info); } for (auto &out : specs.outputs()) { TensorInfo info; info.data_type = ProtoTensor::TransDataType2Inference(out.dtype()); info.size = out.size(); info.is_no_batch_dim = out.is_no_batch_dim(); for (auto &dim : out.shape().dims()) { info.shape.push_back(dim); } worker_specs->output_infos.push_back(info); } } void GrpcTensorHelper::CopyFromWorkerAgentSpec(const std::vector &worker_specs, proto::AgentRegisterRequest *request) { for (size_t i = 0; i < worker_specs.size(); i++) { auto &spec = worker_specs[i]; auto worker_spec = request->add_agent_spec(); worker_spec->set_rank_id(spec.rank_id); worker_spec->set_batch_size(spec.batch_size); for (auto &method : spec.input_infos) { auto proto_method = worker_spec->add_inputs(); proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); proto_method->set_size(method.size); proto_method->set_is_no_batch_dim(method.is_no_batch_dim); auto proto_shape = proto_method->mutable_shape(); for (auto &dim : method.shape) { proto_shape->add_dims(dim); } } for (auto &method : spec.output_infos) { auto proto_method = worker_spec->add_outputs(); proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); proto_method->set_size(method.size); proto_method->set_is_no_batch_dim(method.is_no_batch_dim); auto proto_shape = proto_method->mutable_shape(); for (auto &dim : method.shape) { proto_shape->add_dims(dim); } } } } Status GrpcTensorHelper::CheckRequestTensor(const proto::Tensor &tensor) { Status status; ProtoTensor tensor_input(const_cast(&tensor)); auto shape = tensor_input.shape(); if (tensor.dtype() == proto::MS_BYTES || tensor.dtype() == proto::MS_STRING) { if (tensor.bytes_val_size() != 1) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Instance tensor check failed: bytes or string type shape batch size can only be 1"; } if (!(shape.size() == 1 && shape[0] == 1) && !shape.empty()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Instance tensor check failed: bytes or string type input " << " shape can only be (1,) or empty, but given shape is " << shape; } } else { bool zero_dim = false; for (auto &shape_item : tensor.shape().dims()) { if (shape_item < 0 || zero_dim) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Tensor check failed: input " << " shape " << shape << " invalid"; } if (shape_item == 0) { zero_dim = true; } } auto item_size = tensor_input.itemsize(); if (item_size == 0) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Tensor check failed: input data type " << tensor.dtype() << " invalid"; } size_t element_num = tensor_input.element_cnt(); auto expect_data_size = element_num * item_size; if (tensor.tensor_data_case() == proto::Tensor::TensorDataCase::kShmData) { if (expect_data_size != tensor.shm_data().data_size()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Tensor check failed: input shared memory data size " << tensor.shm_data().data_size() << " not equal to expected size " << expect_data_size; } } else { if (expect_data_size != tensor.data().size()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Tensor check failed: input data size " << tensor.data().size() << " invalid"; } } } return SUCCESS; } void GrpcTensorHelper::CreateReplyFromErrorMsg(const Status &error_msg, proto::PredictReply *reply) { MSI_EXCEPTION_IF_NULL(reply); if (error_msg == SUCCESS) { 
return;
  }
  reply->clear_error_msg();
  reply->clear_instances();
  auto proto_error_msg = reply->add_error_msg();
  proto_error_msg->set_error_code(error_msg.StatusCode());
  std::string error_msg_str = error_msg.StatusMessage();
  if (error_msg_str.empty()) {
    proto_error_msg->set_error_msg("Predict failed");
  } else {
    proto_error_msg->set_error_msg(error_msg_str);
  }
}

serving::LogStream &operator<<(serving::LogStream &stream, proto::DataType data_type) {
  const std::map<proto::DataType, std::string> type_name_map{
    {proto::MS_UNKNOWN, "proto::MS_UNKNOWN"}, {proto::MS_BOOL, "proto::MS_BOOL"},
    {proto::MS_INT8, "proto::MS_INT8"},       {proto::MS_UINT8, "proto::MS_UINT8"},
    {proto::MS_INT16, "proto::MS_INT16"},     {proto::MS_UINT16, "proto::MS_UINT16"},
    {proto::MS_INT32, "proto::MS_INT32"},     {proto::MS_UINT32, "proto::MS_UINT32"},
    {proto::MS_INT64, "proto::MS_INT64"},     {proto::MS_UINT64, "proto::MS_UINT64"},
    {proto::MS_FLOAT16, "proto::MS_FLOAT16"}, {proto::MS_FLOAT32, "proto::MS_FLOAT32"},
    {proto::MS_FLOAT64, "proto::MS_FLOAT64"}, {proto::MS_STRING, "proto::MS_STRING"},
    {proto::MS_BYTES, "proto::MS_BYTES"},
  };
  auto it = type_name_map.find(data_type);
  if (it != type_name_map.end()) {
    stream << it->second;
  } else {
    stream << "proto::MS_UNKNOWN";
  }
  return stream;
}

}  // namespace mindspore::serving
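// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the repository sources): the two conversion
// helpers above are inverses for every defined dtype; unknown values degrade
// to MS_UNKNOWN / kMSI_Unknown with a warning. Wrapped in the serving
// namespace so proto:: resolves exactly as in the sources above.
// ---------------------------------------------------------------------------
#include "common/proto_tensor.h"

namespace mindspore::serving {
void DtypeRoundTripDemo() {
  proto::DataType proto_dt = ProtoTensor::TransDataType2Proto(kMSI_Float32);
  DataType dt = ProtoTensor::TransDataType2Inference(proto_dt);
  MSI_LOG(INFO) << "round trip ok: " << (dt == kMSI_Float32);
}
}  // namespace mindspore::serving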
================================================
FILE: mindspore_serving/ccsrc/common/proto_tensor.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_PROTO_TENSOR_H_
#define MINDSPORE_SERVING_PROTO_TENSOR_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/serving_common.h"
#include "proto/ms_service.pb.h"
#include "proto/ms_master.pb.h"
#include "proto/ms_distributed.pb.h"
#include "common/instance.h"
#include "common/servable.h"
#include "common/shared_memory.h"

namespace mindspore::serving {

class ProtoTensor : public TensorBase {
 public:
  // the other's lifetime must be longer than this object
  explicit ProtoTensor(proto::Tensor *other);
  ~ProtoTensor();
  DataType data_type() const override;
  void set_data_type(DataType type) override;
  std::vector<int64_t> shape() const override;
  void set_shape(const std::vector<int64_t> &shape) override;
  const uint8_t *data() const override;
  size_t data_size() const override;
  bool resize_data(size_t data_len) override;
  uint8_t *mutable_data() override;
  void clear_bytes_data() override;
  void add_bytes_data(const uint8_t *data, size_t bytes_len) override;
  size_t bytes_data_size() const override;
  void get_bytes_data(size_t index, const uint8_t **data, size_t *bytes_len) const override;
  static proto::DataType TransDataType2Proto(DataType data_type);
  static DataType TransDataType2Inference(proto::DataType data_type);
  void SetSharedMemory(const proto::ShmTensorData &shm_data_proto);
  Status AttachSharedMemory() const;

 private:
  // if tensor_ is a reference from another ms_serving::Tensor, the other's lifetime must
  // be longer than this object
  proto::Tensor *tensor_;
  mutable bool has_attached_shm_ = false;
  mutable SharedMemoryAttachItem shm_attach_;
};

class MS_API GrpcTensorHelper {
 public:
  static void GetRequestSpec(const proto::PredictRequest &request, RequestSpec *request_spec);
  static void ConvertProtoWorkerSpec(const proto::RegisterRequest &proto_request, WorkerRegSpec *worker_spec);
  static void ConvertWorkerSpec(const WorkerRegSpec &worker_spec, proto::RegisterRequest *proto_request);
  static void ConvertProtoModelInfos(const proto::ModelInfos &proto_model_infos,
                                     std::map<std::string, ModelInfo> *model_infos);
  static void ConvertModelInfos(const std::map<std::string, ModelInfo> &model_infos,
                                proto::ModelInfos *proto_model_infos);
  static Status CreateInstanceFromRequest(const MethodSignature &method, const proto::PredictRequest &request,
                                          std::vector<InstanceData> *results);
  static void CreateReplyFromInstances(const proto::PredictRequest &request, const MethodSignature &method,
                                       const std::vector<InstancePtr> &instances, proto::PredictReply *reply);
  static void CreateReplyFromErrorMsg(const Status &error_msg, proto::PredictReply *reply);
  static void CopyFromAgentSpec(const proto::AgentSpec &request, WorkerAgentSpec *worker_specs);
  static void CopyFromWorkerAgentSpec(const std::vector<WorkerAgentSpec> &worker_specs,
                                      proto::AgentRegisterRequest *request);
  static Status CreatePredictRequestFromInstances(const RequestSpec &request_spec,
                                                  const std::vector<const proto::Instance *> &instances,
                                                  proto::PredictRequest *request);
  static Status CreatePredictReplyFromInstances(const proto::PredictRequest &request,
                                                const std::vector<proto::ErrorMsg> &errors,
                                                const std::vector<const proto::Instance *> &instances,
                                                proto::PredictReply *reply);
  static Status CreateInstanceFromPredictReply(const RequestSpec &request_spec, const proto::PredictReply &reply,
                                               std::vector<proto::ErrorMsg> *error,
                                               std::vector<const proto::Instance *> *results);
  static Status CheckRequestInstances(const proto::PredictRequest &request,
                                      const std::vector<std::string> &input_names);

 private:
  static Status CreateInstanceFromRequestInstances(const proto::PredictRequest &request,
                                                   const MethodSignature &method,
                                                   std::vector<InstanceData> *results);
  static Status CheckRequestTensor(const proto::Tensor &tensor);
  static Status CreateReplyFromInstancesInner(const proto::PredictRequest &request, const MethodSignature &method,
                                              const std::vector<InstancePtr> &instances, proto::PredictReply *reply);
};

extern MS_API LogStream &operator<<(serving::LogStream &stream, proto::DataType data_type);

}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_PROTO_TENSOR_H_
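// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the repository sources): wrapping a
// proto::Tensor with the ProtoTensor adaptor above to fill a float32 tensor of
// shape (1, 2). Wrapped in the serving namespace so proto:: resolves exactly
// as in the sources; the demo function name is an assumption.
// ---------------------------------------------------------------------------
#include "common/proto_tensor.h"

namespace mindspore::serving {
void FillTensorDemo(proto::Tensor *msg) {
  ProtoTensor tensor(msg);
  tensor.set_data_type(kMSI_Float32);
  tensor.set_shape({1, 2});
  if (tensor.resize_data(2 * sizeof(float))) {
    auto *data = reinterpret_cast<float *>(tensor.mutable_data());
    data[0] = 1.0f;
    data[1] = 2.0f;
  }
}
}  // namespace mindspore::serving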
std::vector> &return_inputs) { MethodStage stage; stage.method_name = method_name; stage.stage_index = stage_index; stage.stage_key = "return"; stage.stage_type = kMethodStageTypeReturn; stage.stage_inputs = return_inputs; stage_map[stage_index] = stage; } size_t MethodSignature::GetStageMax() const { return stage_index; } const MethodSignature *ServableSignature::GetMethodDeclare(const std::string &method_name) const { auto item = find_if(methods.begin(), methods.end(), [&](const MethodSignature &v) { return v.method_name == method_name; }); if (item == methods.end()) { return nullptr; } return &(*item); } const ModelMeta *ServableSignature::GetModelDeclare(const std::string &model_key) const { auto item = find_if(model_metas.begin(), model_metas.end(), [&](const ModelMeta &v) { return v.common_meta.model_key == model_key; }); if (item == model_metas.end()) { return nullptr; } return &(*item); } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/common/servable.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_SERVABLE_H #define MINDSPORE_SERVING_SERVABLE_H #include #include #include #include #include #include #include "common/serving_common.h" #include "worker/inference/inference.h" namespace mindspore::serving { enum MethodStageType { kMethodStageTypeNone = 0, kMethodStageTypePyFunction, kMethodStageTypeCppFunction, kMethodStageTypeModel, kMethodStageTypeReturn, }; struct MethodStage { std::string method_name; uint64_t stage_index = 0; std::string stage_key; // function name, model name std::string tag; MethodStageType stage_type; uint64_t subgraph = 0; // when model std::vector> stage_inputs; // first: input- 0, stage- 1~n, second: output index // will be updated when model loaded uint64_t batch_size = 0; }; static const uint64_t kStageStartIndex = 1; struct MS_API MethodSignature { std::string servable_name; std::string method_name; std::vector inputs; std::vector outputs; std::map stage_map; // stage_index, MethodStage void AddStageFunction(const std::string &func_name, const std::vector> &stage_inputs, uint64_t batch_size = 0, const std::string &tag = ""); void AddStageModel(const std::string &model_key, const std::vector> &stage_inputs, uint64_t subgraph = 0, const std::string &tag = ""); void SetReturn(const std::vector> &return_inputs); // the max stage is return, when reach max stage, all stage works done size_t GetStageMax() const; private: // stage index begin with 1, 0 reserve for input, include function, model, return stage size_t stage_index = kStageStartIndex; }; struct ServableLoadSpec { std::string servable_directory; std::string servable_name; uint64_t version_number = 0; std::string Repr() const; }; struct ServableMethodInfo { std::string name; std::vector input_names; bool only_model_stage = false; }; struct ModelSubgraphInfo { std::vector input_infos; std::vector output_infos; }; struct ModelInfo 
struct ModelInfo {
  std::vector<ModelSubgraphInfo> sub_graph_infos;
  uint64_t batch_size = 0;
};

struct ServableRegSpec {
  std::string servable_name;
  uint64_t version_number = 0;
  uint64_t batch_size = 0;
  bool own_device = true;
  std::vector<ServableMethodInfo> methods;
  std::map<std::string, ModelInfo> models;
};

struct WorkerRegSpec {
  uint64_t worker_pid = 0;
  std::string worker_address;
  ServableRegSpec servable_spec;
  std::string Repr() const;
};

struct RequestSpec {
  std::string servable_name;
  std::string method_name;
  uint64_t version_number = 0;  // 0 means not specified
  std::string Repr() const;
};

enum ServableType {
  kServableTypeUnknown = 0,
  kServableTypeLocal = 1,
  kServableTypeDistributed = 2,
};

struct CommonModelMeta {
  std::string servable_name;
  // used to identify the model; for a local model: ";".join(model_files), for a distributed model: the servable name
  std::string model_key;
  bool with_batch_dim = true;  // whether there is a batch dim in the model's inputs/outputs
  std::vector<std::string> without_batch_dim_inputs;
  std::map<uint64_t, size_t> inputs_count;
  std::map<uint64_t, size_t> outputs_count;
};

struct MS_API LocalModelMeta {
  std::vector<std::string> model_files;               // file names
  ModelType model_format = ModelType::kUnknownType;   // OM, MindIR, MindIR_Lite
  ModelContext model_context;
  std::string config_file;
  void SetModelFormat(const std::string &format);
};

struct DistributedModelMeta {
  size_t rank_size = 0;
  size_t stage_size = 0;
  bool enable_pipeline_infer = false;
};

struct MS_API ModelMeta {
  CommonModelMeta common_meta;
  LocalModelMeta local_meta;
  DistributedModelMeta distributed_meta;
};

struct MS_API ServableSignature {
  ServableType servable_type = kServableTypeUnknown;
  std::string servable_name;
  std::vector<ModelMeta> model_metas;
  std::vector<MethodSignature> methods;
  const MethodSignature *GetMethodDeclare(const std::string &method_name) const;
  const ModelMeta *GetModelDeclare(const std::string &model_key) const;
};

struct WorkerAgentSpec {
  std::string agent_address;
  uint32_t rank_id = 0;
  std::vector<TensorInfo> input_infos;
  std::vector<TensorInfo> output_infos;
  uint32_t batch_size = 0;
  uint64_t subgraph = 0;
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_SERVABLE_H

================================================
FILE: mindspore_serving/ccsrc/common/serving_common.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SERVING_SERVING_COMMON_H
#define MINDSPORE_SERVING_SERVING_COMMON_H

#include <memory>
#include "common/status.h"
#include "common/log.h"
#include "common/tensor.h"
#include "common/utils.h"

#endif  // MINDSPORE_SERVING_SERVING_COMMON_H
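A minimal caller-side sketch against the declarations above (the method name is hypothetical, not from this repository): GetMethodDeclare and GetModelDeclare return nullptr for unknown keys, so callers are expected to check before dereferencing.

    // Hypothetical lookup, assuming a populated ServableSignature `sig`
    // inside a function returning Status
    const MethodSignature *method = sig.GetMethodDeclare("add_common");
    if (method == nullptr) {
      MSI_LOG_ERROR << "method add_common is not declared";
      return FAILED;  // Status is implicitly constructed from StatusCode
    }
    MSI_LOG_INFO << "max stage index of method: " << method->GetStageMax();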
================================================
FILE: mindspore_serving/ccsrc/common/shared_memory.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include "common/shared_memory.h"

namespace mindspore {
namespace serving {
SharedMemoryAllocator &SharedMemoryAllocator::Instance() {
  static SharedMemoryAllocator instance = SharedMemoryAllocator();
  return instance;
}

SharedMemoryAllocator::SharedMemoryAllocator() = default;

SharedMemoryAllocator::~SharedMemoryAllocator() noexcept {
  std::unique_lock<std::mutex> lock(lock_);
  for (auto &item : memory_map_) {
    auto &group = item.second;
    for (auto &shm : group.shm_map) {
      auto ret = munmap(shm.second.address, shm.second.bytes_size);
      if (ret == -1) {
        MSI_LOG_ERROR << "Failed to munmap, memory key: " << shm.second.memory_key;
      }
      ret = shm_unlink(shm.second.memory_key.c_str());
      if (ret == -1) {
        MSI_LOG_ERROR << "Failed to shm_unlink " << shm.second.memory_key << ", errno: " << errno;
      }
    }
  }
  memory_map_.clear();
}

Status SharedMemoryAllocator::AddShmMemoryBuffer(SharedMemoryGroup *shm_group) {
  auto item_size = shm_group->item_size;
  auto item_count = shm_group->item_count;
  auto memory_key = shm_group->memory_key_prefix + "_" + std::to_string(shm_group->shm_map.size());
  // maximum 4GB memory
  if (item_size == 0 || item_count == 0 || UINT32_MAX / item_size < item_count) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid item size or item count, item size: " << item_size
                                          << ", item count :" << item_count << ", memory key: " << memory_key;
  }
  constexpr uint32_t align_size = 8;
  auto align_item_size = (item_size + align_size - 1) / align_size * align_size;
  auto shm_fd = shm_open(memory_key.c_str(), O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
  if (shm_fd == -1) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to shm_open " << memory_key << " , errno: " << errno;
  }
  uint64_t memory_size = align_item_size * item_count;
  auto ret = ftruncate(shm_fd, static_cast<off_t>(memory_size));
  if (ret == -1) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to ftruncate " << memory_key << ", errno: " << errno
                                          << ", memory size: " << memory_size;
  }
  auto address = mmap(nullptr, memory_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
  if (address == MAP_FAILED) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to mmap " << memory_key << ", errno: " << errno
                                          << ", memory size: " << memory_size;
  }
  ret = close(shm_fd);
  if (ret == -1) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to close " << memory_key << ", errno: " << errno;
  }
  SharedMemory &shm = shm_group->shm_map[memory_key];
  shm.memory_key = memory_key;
  shm.address = reinterpret_cast<uint8_t *>(address);
  shm.bytes_size = memory_size;
  uint64_t offset = 0;
  for (uint64_t i = 0; i < item_count; i++) {
    (void)shm.free_queue.emplace(offset);
    offset += align_item_size;
  }
  shm_group->free_count += item_count;
  MSI_LOG_INFO << "New shared memory success, memory key: " << memory_key << ", bytes size: " << memory_size
               << ", item count: " << item_count;
  return SUCCESS;
}
Status SharedMemoryAllocator::NewMemoryBuffer(const std::string &memory_key_prefix, uint64_t item_size,
                                              uint64_t item_count) {
  std::unique_lock<std::mutex> lock(lock_);
  if (memory_map_.find(memory_key_prefix) != memory_map_.end()) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Shared memory has already been inited";
  }
  auto &group = memory_map_[memory_key_prefix];
  group.memory_key_prefix = memory_key_prefix;
  group.item_size = item_size;
  group.item_count = item_count;
  group.free_count = 0;
  auto status = AddShmMemoryBuffer(&group);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Alloc shared memory failed, memory key prefix: " << memory_key_prefix;
    return status;
  }
  return SUCCESS;
}

Status SharedMemoryAllocator::AllocMemoryItem(const std::string &memory_key_prefix, SharedMemoryItem *shm_item) {
  std::unique_lock<std::mutex> lock(lock_);
  auto it = memory_map_.find(memory_key_prefix);
  if (it == memory_map_.end()) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Cannot find shared memory " << memory_key_prefix;
  }
  auto &group = it->second;
  if (group.free_count == 0) {
    auto status = AddShmMemoryBuffer(&group);
    if (status != SUCCESS) {
      MSI_LOG_ERROR << "Alloc shared memory failed, memory key prefix: " << memory_key_prefix;
      return status;
    }
  }
  for (auto &item : group.shm_map) {
    auto &shm = item.second;
    if (!shm.free_queue.empty()) {
      shm_item->memory_key_prefix = memory_key_prefix;
      shm_item->memory_key = shm.memory_key;
      shm_item->bytes_size = shm.bytes_size;
      shm_item->offset = *shm.free_queue.begin();
      shm_item->offset_address = shm.address + shm_item->offset;
      shm_item->size = group.item_size;
      (void)shm.free_queue.erase(shm_item->offset);
      group.free_count -= 1;
      return SUCCESS;
    }
  }
  MSI_LOG_EXCEPTION << "There is no free shared memory";
}

void SharedMemoryAllocator::ReleaseMemoryItem(const SharedMemoryItem &shm_item) {
  std::unique_lock<std::mutex> lock(lock_);
  auto it = memory_map_.find(shm_item.memory_key_prefix);
  if (it == memory_map_.end()) {
    MSI_LOG_WARNING << "Cannot find shared memory prefix " << shm_item.memory_key_prefix;
    return;
  }
  auto shm_it = it->second.shm_map.find(shm_item.memory_key);
  if (shm_it == it->second.shm_map.end()) {
    MSI_LOG_WARNING << "Cannot find shared memory " << shm_item.memory_key;
    return;
  }
  if (shm_it->second.free_queue.count(shm_item.offset) > 0) {
    MSI_LOG_EXCEPTION << "Shared memory " << shm_item.memory_key
                      << " has already been in free set, offset: " << shm_item.offset;
  }
  (void)shm_it->second.free_queue.emplace(shm_item.offset);
  it->second.free_count += 1;
}

ShmTensor::ShmTensor(DataType type, const std::vector<int64_t> &shape, const SharedMemoryItem &shm_item)
    : BufferTensor(type, shape, shm_item.offset_address, shm_item.size, false), shm_info_(shm_item) {}

ShmTensor::~ShmTensor() noexcept { SharedMemoryAllocator::Instance().ReleaseMemoryItem(shm_info_); }

SharedMemoryManager &SharedMemoryManager::Instance() {
  static SharedMemoryManager instance = SharedMemoryManager();
  return instance;
}

SharedMemoryManager::SharedMemoryManager() {}

SharedMemoryManager::~SharedMemoryManager() noexcept {
  std::unique_lock<std::mutex> lock(lock_);
  for (auto &item : attached_shm_list_) {
    auto ret = munmap(item.address, item.bytes_size);
    if (ret == -1) {
      MSI_LOG_ERROR << "Failed to munmap, memory key: " << item.memory_key;
    }
  }
  attached_shm_list_.clear();
}
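// Illustrative lifecycle sketch (not part of this file): a producer-side buffer
// pool is created once, items are leased per request and handed to ShmTensor,
// whose destructor returns the lease. The key prefix below is hypothetical.
//
//   auto &allocator = SharedMemoryAllocator::Instance();
//   (void)allocator.NewMemoryBuffer("serving_shm_example", /*item_size=*/4096, /*item_count=*/16);
//   SharedMemoryItem item;
//   if (allocator.AllocMemoryItem("serving_shm_example", &item) == SUCCESS) {
//     // wrap the leased block as a tensor; the lease is released in ~ShmTensor()
//     auto tensor = std::make_shared<ShmTensor>(kMSI_Float32, std::vector<int64_t>{1, 1024}, item);
//   }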
Status SharedMemoryManager::Attach(const std::string &memory_key, uint64_t bytes_size, uint64_t data_offset,
                                   uint64_t data_size, SharedMemoryAttachItem *shm_info) {
  if (data_size > bytes_size || data_offset > bytes_size - data_size) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid memory size info, memory key: " << memory_key
                                          << ", bytes size: " << bytes_size << ", data offset: " << data_offset
                                          << ", data size: " << data_size;
  }
  SharedMemoryAttach attach_mem;
  auto status = Attach(memory_key, bytes_size, &attach_mem);
  if (status != SUCCESS) {
    return status;
  }
  shm_info->memory_key = attach_mem.memory_key;
  shm_info->offset_address = attach_mem.address + data_offset;
  shm_info->offset = data_offset;
  shm_info->size = data_size;
  return SUCCESS;
}

Status SharedMemoryManager::Detach(const std::string &memory_key) {
  std::unique_lock<std::mutex> lock(lock_);
  auto it = std::find_if(attached_shm_list_.begin(), attached_shm_list_.end(),
                         [&memory_key](const SharedMemoryAttach &item) { return memory_key == item.memory_key; });
  if (it == attached_shm_list_.end()) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Cannot find shared memory " << memory_key;
  }
  auto ret = munmap(it->address, it->bytes_size);
  if (ret == -1) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to munmap, memory key: " << memory_key;
  }
  (void)attached_shm_list_.erase(it);
  return SUCCESS;
}

Status SharedMemoryManager::Attach(const std::string &memory_key, uint64_t bytes_size,
                                   SharedMemoryAttach *attach_mem) {
  std::unique_lock<std::mutex> lock(lock_);
  for (auto &item : attached_shm_list_) {
    if (item.memory_key == memory_key) {
      *attach_mem = item;
      return SUCCESS;
    }
  }
  auto shm_fd = shm_open(memory_key.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
  if (shm_fd == -1) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to shm_open " << memory_key << " , errno: " << errno;
  }
  auto address = mmap(nullptr, bytes_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
  if (address == MAP_FAILED) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to mmap " << memory_key << ", errno: " << errno
                                          << ", memory size: " << bytes_size;
  }
  auto ret = close(shm_fd);
  if (ret == -1) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to close " << memory_key << ", errno: " << errno;
  }
  attach_mem->memory_key = memory_key;
  attach_mem->bytes_size = bytes_size;
  attach_mem->address = static_cast<uint8_t *>(address);
  attached_shm_list_.push_back(*attach_mem);
  return SUCCESS;
}
}  // namespace serving
}  // namespace mindspore

================================================
FILE: mindspore_serving/ccsrc/common/shared_memory.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_SHARED_MEMORY_H
#define MINDSPORE_SERVING_SHARED_MEMORY_H

#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <string>
#include <vector>
#include "common/serving_common.h"
#include "common/buffer_tensor.h"

namespace mindspore {
namespace serving {
struct SharedMemoryItem {
  std::string memory_key_prefix;
  std::string memory_key;   // for shm_open
  uint64_t bytes_size = 0;  // for shm_open
  uint8_t *offset_address = nullptr;
  uint64_t offset = 0;
  uint64_t size = 0;
};

struct SharedMemory {
  std::string memory_key;
  uint64_t bytes_size = 0;
  uint8_t *address = nullptr;
  std::set<uint64_t> free_queue;
};

struct SharedMemoryGroup {
  std::map<std::string, SharedMemory> shm_map;
  std::string memory_key_prefix;
  uint64_t item_size = 0;
  uint64_t item_count = 0;
  uint64_t free_count = 0;
};

class SharedMemoryAllocator {
 public:
  static SharedMemoryAllocator &Instance();
  SharedMemoryAllocator();
  ~SharedMemoryAllocator() noexcept;
  Status NewMemoryBuffer(const std::string &memory_key_prefix, uint64_t item_size, uint64_t init_item_count);
  Status AllocMemoryItem(const std::string &memory_key_prefix, SharedMemoryItem *shm_item);
  void ReleaseMemoryItem(const SharedMemoryItem &shm_item);

 private:
  std::map<std::string, SharedMemoryGroup> memory_map_;
  std::mutex lock_;
  Status AddShmMemoryBuffer(SharedMemoryGroup *shm_group);
};

class ShmTensor : public BufferTensor {
 public:
  ShmTensor(DataType type, const std::vector<int64_t> &shape, const SharedMemoryItem &shm_item);
  ~ShmTensor() noexcept;

 private:
  SharedMemoryItem shm_info_;
};

struct SharedMemoryAttach {
  std::string memory_key;
  uint64_t bytes_size = 0;
  uint8_t *address = nullptr;
};

struct SharedMemoryAttachItem {
  std::string memory_key;  // for shm_open
  uint8_t *offset_address = nullptr;
  uint64_t offset = 0;
  uint64_t size = 0;
};

class SharedMemoryManager {
 public:
  static SharedMemoryManager &Instance();
  SharedMemoryManager();
  ~SharedMemoryManager() noexcept;
  Status Attach(const std::string &memory_key, uint64_t bytes_size, uint64_t data_offset, uint64_t data_size,
                SharedMemoryAttachItem *shm_info);
  Status Detach(const std::string &memory_key);

 private:
  std::vector<SharedMemoryAttach> attached_shm_list_;
  std::mutex lock_;
  Status Attach(const std::string &memory_key, uint64_t bytes_size, SharedMemoryAttach *attach_mem);
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_SHARED_MEMORY_H
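A cross-process sketch of the attach side (hypothetical key and sizes, not from this repository): the consumer maps a region that a producer already created under the same shm_open key, then reads at the offset carried in the request.

    // Hypothetical consumer, assuming a producer created "serving_shm_example_0"
    // with 16 items of 4096 bytes each
    SharedMemoryAttachItem info;
    Status status = SharedMemoryManager::Instance().Attach(
      "serving_shm_example_0", /*bytes_size=*/4096 * 16, /*data_offset=*/4096, /*data_size=*/4096, &info);
    if (status == SUCCESS) {
      // info.offset_address now points at the producer-written bytes;
      // the mapping is cached until Detach is called for the key
      (void)SharedMemoryManager::Instance().Detach("serving_shm_example_0");
    }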
================================================
FILE: mindspore_serving/ccsrc/common/ssl_config.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SERVING_SSL_CONFIG_H
#define MINDSPORE_SERVING_SSL_CONFIG_H

#include <string>

namespace mindspore::serving {
struct SSLConfig {
  std::string certificate;
  std::string private_key;
  std::string custom_ca;
  bool verify_client{false};
  bool use_ssl{false};
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_SSL_CONFIG_H

================================================
FILE: mindspore_serving/ccsrc/common/status.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SERVING_STATUS_H
#define MINDSPORE_SERVING_STATUS_H

#include <chrono>
#include <memory>
#include <string>
#include "common/log.h"

namespace mindspore::serving {
enum StatusCode {
  SUCCESS = 0,
  FAILED,
  INVALID_INPUTS,
  SYSTEM_ERROR,
  WORKER_UNAVAILABLE,
  SERVABLE_UNAVAILABLE,
};

class Status {
 public:
  Status() : status_code_(FAILED) {}
  Status(enum StatusCode status_code, const std::string &status_msg = "")  // NOLINT(runtime/explicit)
      : status_code_(status_code), status_msg_(status_msg) {}
  bool IsSuccess() const { return status_code_ == SUCCESS; }
  enum StatusCode StatusCode() const { return status_code_; }
  std::string StatusMessage() const { return status_msg_; }
  bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
  bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
  bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
  bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
  operator bool() const = delete;
  Status &operator<(const LogStream &stream) noexcept __attribute__((visibility("default"))) {
    status_msg_ = stream.sstream_->str();
    return *this;
  }
  Status &operator=(const std::string &msg) noexcept __attribute__((visibility("default"))) {
    status_msg_ = msg;
    return *this;
  }

 private:
  enum StatusCode status_code_;
  std::string status_msg_;
};

#define MSI_TIME_STAMP_START(name) auto time_start_##name = std::chrono::steady_clock::now();
#define MSI_TIME_STAMP_END(name)                                                                             \
  {                                                                                                          \
    auto time_end_##name = std::chrono::steady_clock::now();                                                 \
    auto time_cost = std::chrono::duration<double, std::milli>(time_end_##name - time_start_##name).count(); \
    MSI_LOG_INFO << #name " Time Cost # " << time_cost << " ms ---------------------";                       \
  }
#define MSI_TIME_STAMP_END_EXTRA(name, extra)                                                                \
  {                                                                                                          \
    auto time_end_##name = std::chrono::steady_clock::now();                                                 \
    auto time_cost = std::chrono::duration<double, std::milli>(time_end_##name - time_start_##name).count(); \
    MSI_LOG_INFO << extra << " " << #name " Time Cost # " << time_cost << " ms ---------------------";       \
  }

#define INFER_STATUS(code) mindspore::serving::Status(code) < mindspore::serving::LogStream()
#define INFER_STATUS_LOG_ERROR(code) mindspore::serving::Status(code) = MSILOG_NOIF(ERROR)
#define INFER_STATUS_LOG_WARNING(code) mindspore::serving::Status(code) = MSILOG_NOIF(WARNING)
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_STATUS_H

================================================
FILE:
mindspore_serving/ccsrc/common/tensor.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "common/tensor.h" #include #include #include #include "common/log.h" namespace mindspore::serving { Tensor::Tensor() = default; Tensor::Tensor(DataType type, const std::vector &shape, const void *data, size_t data_len) : type_(type), shape_(shape) { (void)set_data(data, data_len); } const uint8_t *Tensor::data() const { if (data_size() == 0) { return nullptr; } return data_.data(); } size_t Tensor::data_size() const { return data_.size(); } bool Tensor::resize_data(size_t data_len) { data_.resize(data_len); return true; } uint8_t *Tensor::mutable_data() { if (data_size() == 0) { return nullptr; } return data_.data(); } // For kMSI_String and kMSI_Bytes void Tensor::clear_bytes_data() { bytes_.clear(); } void Tensor::add_bytes_data(const uint8_t *data, size_t bytes_len) { std::vector bytes(bytes_len); (void)memcpy_s(bytes.data(), bytes.size(), data, bytes_len); bytes_.push_back(std::move(bytes)); } size_t Tensor::bytes_data_size() const { return bytes_.size(); } void Tensor::get_bytes_data(size_t index, const uint8_t **data, size_t *bytes_len) const { MSI_EXCEPTION_IF_NULL(data); MSI_EXCEPTION_IF_NULL(bytes_len); *bytes_len = bytes_[index].size(); if (*bytes_len == 0) { *data = nullptr; } else { *data = bytes_[index].data(); } } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/common/tensor.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

#ifndef MINDSPORE_SERVING_TENSOR_H
#define MINDSPORE_SERVING_TENSOR_H

#include <vector>
#include "common/tensor_base.h"

namespace mindspore::serving {
class MS_API Tensor : public TensorBase {
 public:
  Tensor();
  Tensor(DataType type, const std::vector<int64_t> &shape, const void *data, size_t data_len);
  ~Tensor() = default;
  void set_data_type(DataType type) override { type_ = type; }
  DataType data_type() const override { return type_; }
  void set_shape(const std::vector<int64_t> &shape) override { shape_ = shape; }
  std::vector<int64_t> shape() const override { return shape_; }
  const uint8_t *data() const override;
  size_t data_size() const override;
  bool resize_data(size_t data_len) override;
  uint8_t *mutable_data() override;
  // For kMSI_String and kMSI_Bytes
  void clear_bytes_data() override;
  void add_bytes_data(const uint8_t *data, size_t bytes_len) override;
  size_t bytes_data_size() const override;
  void get_bytes_data(size_t index, const uint8_t **data, size_t *bytes_len) const override;

 private:
  DataType type_ = kMSI_Unknown;
  std::vector<int64_t> shape_;
  std::vector<uint8_t> data_;
  // For kMSI_String and kMSI_Bytes
  std::vector<std::vector<uint8_t>> bytes_;
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_TENSOR_H

================================================
FILE: mindspore_serving/ccsrc/common/tensor_base.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include "common/tensor_base.h" #include #include #include "common/log.h" #define TENSOR_MAX_ELEMENT_COUNT UINT32_MAX namespace mindspore::serving { TensorBase::TensorBase() = default; TensorBase::~TensorBase() = default; bool TensorBase::set_data(const void *data, size_t data_len) { if (data_size() != data_len) { (void)resize_data(data_len); if (data_len == 0) { MSI_LOG_INFO << "set data to data len 0"; return true; } } if (mutable_data() == nullptr) { MSI_LOG_ERROR << "set data failed, data len " << data_len; return false; } if (data_size() != data_len) { MSI_LOG_ERROR << "set data failed, tensor current data size " << data_size() << " not match data len " << data_len; return false; } (void)memcpy_s(mutable_data(), data_size(), data, data_len); return true; } size_t TensorBase::itemsize() const { return GetTypeSize(data_type()); } size_t TensorBase::element_cnt() const { size_t element_num = 1; for (auto dim : shape()) { if (dim <= 0 || TENSOR_MAX_ELEMENT_COUNT / static_cast(dim) < element_num) { return 0; } element_num *= static_cast(dim); } return element_num; } size_t TensorBase::GetTypeSize(DataType type) { const std::map type_size_map{ {kMSI_Bool, sizeof(bool)}, {kMSI_Float64, sizeof(double)}, {kMSI_Int8, sizeof(int8_t)}, {kMSI_Uint8, sizeof(uint8_t)}, {kMSI_Int16, sizeof(int16_t)}, {kMSI_Uint16, sizeof(uint16_t)}, {kMSI_Int32, sizeof(int32_t)}, {kMSI_Uint32, sizeof(uint32_t)}, {kMSI_Int64, sizeof(int64_t)}, {kMSI_Uint64, sizeof(uint64_t)}, {kMSI_Float16, sizeof(uint16_t)}, {kMSI_Float32, sizeof(float)}, }; auto it = type_size_map.find(type); if (it != type_size_map.end()) { return it->second; } return 0; } void TensorBase::assign(const TensorBase &other) { if (is_bytes_val_data()) { clear_bytes_data(); } set_shape(other.shape()); set_data_type(other.data_type()); if (other.is_bytes_val_data()) { for (size_t i = 0; i < other.bytes_data_size(); i++) { const uint8_t *data; size_t data_len; other.get_bytes_data(i, &data, &data_len); add_bytes_data(data, data_len); } } else { (void)set_data(other.data(), other.data_size()); } } LogStream &operator<<(LogStream &stream, DataType data_type) { const std::map type_name_map{ {kMSI_Unknown, "kMSI_Unknown"}, {kMSI_Bool, "kMSI_Bool"}, {kMSI_Int8, "kMSI_Int8"}, {kMSI_Uint8, "kMSI_Uint8"}, {kMSI_Int16, "kMSI_Int16"}, {kMSI_Uint16, "kMSI_Uint16"}, {kMSI_Int32, "kMSI_Int32"}, {kMSI_Uint32, "kMSI_Uint32"}, {kMSI_Int64, "kMSI_Int64"}, {kMSI_Uint64, "kMSI_Uint64"}, {kMSI_Float16, "kMSI_Float16"}, {kMSI_Float32, "kMSI_Float32"}, {kMSI_Float64, "kMSI_Float64"}, {kMSI_Bytes, "kMSI_Bytes"}, {kMSI_String, "kMSI_String"}, }; auto it = type_name_map.find(data_type); if (it != type_name_map.end()) { stream << it->second; } else { stream << "kMSI_Unknown"; } return stream; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/common/tensor_base.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_TENSOR_BASE_H #define MINDSPORE_SERVING_TENSOR_BASE_H #include #include #include #include #include #include #include "common/log.h" #include "common/status.h" namespace mindspore { namespace serving { enum DataType { kMSI_Unknown = 0, kMSI_Bool = 1, kMSI_Int8 = 2, kMSI_Int16 = 3, kMSI_Int32 = 4, kMSI_Int64 = 5, kMSI_Uint8 = 6, kMSI_Uint16 = 7, kMSI_Uint32 = 8, kMSI_Uint64 = 9, kMSI_Float16 = 10, kMSI_Float32 = 11, kMSI_Float64 = 12, kMSI_String = 13, // for model STRING input kMSI_Bytes = 14, // for image etc. }; class TensorBase; using TensorBasePtr = std::shared_ptr; class MS_API TensorBase : public std::enable_shared_from_this { public: TensorBase(); virtual ~TensorBase(); // For all data type virtual std::vector shape() const = 0; virtual void set_shape(const std::vector &shape) = 0; virtual DataType data_type() const = 0; virtual void set_data_type(DataType type) = 0; // All the following interfaces are not for kMSI_String and kMSI_Bytes virtual const uint8_t *data() const = 0; virtual size_t data_size() const = 0; virtual bool resize_data(size_t data_len) = 0; virtual uint8_t *mutable_data() = 0; // Byte size of a single element. size_t itemsize() const; // Total number of elements. size_t element_cnt() const; // resize and copy data bool set_data(const void *data, size_t data_len); static size_t GetTypeSize(DataType type); // For kMSI_String and kMSI_Bytes virtual void clear_bytes_data() = 0; virtual void add_bytes_data(const uint8_t *data, size_t bytes_len) = 0; virtual size_t bytes_data_size() const = 0; virtual void get_bytes_data(size_t index, const uint8_t **data, size_t *bytes_len) const = 0; // TensorBase(const TensorBase& other) = delete; // TensorBase& operator=(const TensorBase& other) = delete; void assign(const TensorBase &other); bool is_bytes_val_data() const { return data_type() == kMSI_Bytes || data_type() == kMSI_String; } }; class RequestBase { public: RequestBase() = default; virtual ~RequestBase() = default; virtual size_t size() const = 0; virtual const TensorBase *operator[](size_t index) const = 0; }; class ReplyBase { public: ReplyBase() = default; virtual ~ReplyBase() = default; virtual size_t size() const = 0; virtual TensorBase *operator[](size_t index) = 0; virtual const TensorBase *operator[](size_t index) const = 0; virtual TensorBase *add() = 0; virtual void clear() = 0; }; extern MS_API LogStream &operator<<(LogStream &stream, DataType data_type); } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_TENSOR_BASE_H ================================================ FILE: mindspore_serving/ccsrc/common/thread_pool.cc ================================================ /** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "common/thread_pool.h" #include #include #include #include #include #include namespace mindspore::serving { ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false), idle_thrd_num_(size < 1 ? 
1 : size) {
  for (uint32_t i = 0; i < idle_thrd_num_; ++i) {
    (void)pool_.emplace_back(ThreadFunc, this);
  }
}

ThreadPool::~ThreadPool() {
  {
    std::unique_lock<std::mutex> lock{m_lock_};
    is_stoped_.store(true);
    cond_var_.notify_all();
  }
  for (std::thread &thd : pool_) {
    if (thd.joinable()) {
      try {
        thd.join();
      } catch (const std::system_error &) {
      } catch (...) {
      }
    }
  }
}

void ThreadPool::ThreadFunc(ThreadPool *thread_pool) {
  if (thread_pool == nullptr) {
    return;
  }
  while (!thread_pool->is_stoped_) {
    std::function<void()> task;
    {
      std::unique_lock<std::mutex> lock{thread_pool->m_lock_};
      thread_pool->cond_var_.wait(
        lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); });
      if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) {
        return;
      }
      task = std::move(thread_pool->tasks_.front());
      thread_pool->tasks_.pop();
    }
    thread_pool->idle_thrd_num_ -= 1;
    task();
    thread_pool->idle_thrd_num_ += 1;
  }
}
}  // namespace mindspore::serving

================================================
FILE: mindspore_serving/ccsrc/common/thread_pool.h
================================================
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SERVING_THREAD_POOL_H_
#define MINDSPORE_SERVING_THREAD_POOL_H_

#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

namespace mindspore::serving {
using ThreadTask = std::function<void()>;

class ThreadPool {
 public:
  explicit ThreadPool(uint32_t size = 4);
  ~ThreadPool();

  template <typename Func, typename... Args>
  auto commit(Func &&func, Args &&... args) -> std::future<decltype(func(args...))> {
    using retType = decltype(func(args...));
    std::future<retType> fail_future;
    if (is_stoped_.load()) {
      return fail_future;
    }
    auto bindFunc = std::bind(std::forward<Func>(func), std::forward<Args>(args)...);
    auto task = std::make_shared<std::packaged_task<retType()>>(bindFunc);
    if (task == nullptr) {
      return fail_future;
    }
    std::future<retType> future = task->get_future();
    {
      std::lock_guard<std::mutex> lock{m_lock_};
      (void)tasks_.emplace([task]() { (*task)(); });
    }
    cond_var_.notify_one();
    return future;
  }

  static void ThreadFunc(ThreadPool *thread_pool);

 private:
  std::vector<std::thread> pool_;
  std::queue<ThreadTask> tasks_;
  std::mutex m_lock_;
  std::condition_variable cond_var_;
  std::atomic<bool> is_stoped_;
  std::atomic<uint32_t> idle_thrd_num_;
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_THREAD_POOL_H_

================================================
FILE: mindspore_serving/ccsrc/common/utils.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include "common/utils.h" #include #include #include namespace mindspore::serving::common { Status CheckAddress(const std::string &address, const std::string &server_tag, std::string *ip, uint16_t *port) { Status status; auto position = address.find_last_of(':'); if (position == std::string::npos) { status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: The format of the " << server_tag << " address '" << address << "' is illegal"; return status; } if (position == 0 || position == address.size() - 1) { status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: Missing ip or port of the " << server_tag << " address '" << address << "'"; return status; } if (ip != nullptr) { *ip = address.substr(0, position); } try { auto port_number = std::stoi(address.substr(position + 1, address.size())); constexpr int port_min = 1; constexpr int port_max = 65535; if (port_number < port_min || port_number > port_max) { status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: The port of the " << server_tag << " address '" << address << "' is out of legal range [1 ~ 65535]"; return status; } if (port != nullptr) { *port = static_cast(port_number); } } catch (const std::invalid_argument &) { status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: The type of " << server_tag << " address '" << address << "' port is not a number"; return status; } catch (const std::out_of_range &) { status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: The port of the " << server_tag << " address '" << address << "' is out of legal range [1 ~ 65535]"; return status; } return SUCCESS; } bool DirOrFileExist(const std::string &file_path) { int ret = access(file_path.c_str(), 0); return (ret == -1) ? false : true; } } // namespace mindspore::serving::common ================================================ FILE: mindspore_serving/ccsrc/common/utils.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_COMMON_UTILS_H #define MINDSPORE_SERVING_COMMON_UTILS_H #include #include "common/status.h" namespace mindspore::serving::common { static inline std::string GetEnv(const std::string &env_var) { const char *value = ::getenv(env_var.c_str()); if (value == nullptr) { return std::string(); } return std::string(value); } Status CheckAddress(const std::string &address, const std::string &server_tag, std::string *ip, uint16_t *port); bool DirOrFileExist(const std::string &file_path); } // namespace mindspore::serving::common #endif // MINDSPORE_SERVING_COMMON_UTILS_H ================================================ FILE: mindspore_serving/ccsrc/master/dispacther.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "master/dispacther.h" #include #include "common/proto_tensor.h" #include "master/master_context.h" #include "master/notify_worker/grpc_notify.h" namespace mindspore::serving { Dispatcher::Dispatcher() {} Dispatcher::~Dispatcher() { Clear(); } std::shared_ptr Dispatcher::GetWorkerEndpoint(const RequestSpec &request_spec) const { Status status; if (request_spec.version_number > 0) { auto item = find_if(servable_list_.begin(), servable_list_.end(), [&](const std::shared_ptr &v) { return v->GetServableName() == request_spec.servable_name && v->GetVersionNumber() == request_spec.version_number; }); if (item != servable_list_.end()) { return *item; } return nullptr; } uint64_t max_version_number = 0; std::shared_ptr endpoint = nullptr; for (const auto &item : servable_list_) { if (item->GetServableName() == request_spec.servable_name && max_version_number < item->GetVersionNumber()) { endpoint = item; max_version_number = item->GetVersionNumber(); } } return endpoint; } Status Dispatcher::JudgeInferNum() { auto max_enqueued_requests = MasterContext::Instance()->GetMaxEnqueuedRequests(); if (enqueued_requests_ >= max_enqueued_requests) { return INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: enqueued requests count exceeds the limit " << max_enqueued_requests; } return SUCCESS; } void Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &on_finish) { MSI_EXCEPTION_IF_NULL(reply); (*reply->mutable_servable_spec()) = request.servable_spec(); Status status = JudgeInferNum(); if (status != SUCCESS) { GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); on_finish(); return; } try { auto callback = [this, on_finish]() { on_finish(); this->enqueued_requests_--; }; enqueued_requests_++; status = DispatchAsyncInner(request, reply, callback); } catch (const std::bad_alloc &ex) { MSI_LOG(ERROR) << "Serving Error: malloc memory failed"; } catch (const std::runtime_error &ex) { MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what(); } catch (const std::exception &ex) { MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what(); } catch (...) 
{ MSI_LOG(ERROR) << "Serving Error: exception occurred"; } if (status != SUCCESS) { GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); on_finish(); enqueued_requests_--; } } Status Dispatcher::DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &on_finish) { MSI_EXCEPTION_IF_NULL(reply); std::shared_lock lock(servable_shared_lock_); RequestSpec request_spec; GrpcTensorHelper::GetRequestSpec(request, &request_spec); auto endpoint = GetWorkerEndpoint(request_spec); if (endpoint == nullptr) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", servable is not available"; } auto methods = endpoint->GetMethods(); bool find_method = std::any_of(methods.begin(), methods.end(), [&](const ServableMethodInfo &method) { return method.name == request_spec.method_name; }); if (!find_method) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", method is not available"; } return endpoint->DispatchAsync(request, reply, on_finish); } Status Dispatcher::UnregisterServableCommon(const std::string &worker_address) { std::unique_lock lock(servable_shared_lock_); std::shared_ptr worker_context = nullptr; for (auto &item : worker_list_) { if (item->GetWorkerAddress() == worker_address) { worker_context = item; break; } } if (worker_context == nullptr) { MSI_LOG_ERROR << "Cannot find worker context of address " << worker_address; return FAILED; } auto servable_spec = worker_context->GetWorkerSpec().servable_spec; std::shared_ptr endpoint = nullptr; for (auto &item : servable_list_) { if (item->GetServableName() == servable_spec.servable_name && item->GetVersionNumber() == servable_spec.version_number) { endpoint = item; break; } } if (endpoint) { endpoint->UnregisterWorker(worker_address); } worker_context->OnExit(); MSI_LOG_INFO << "Unregister worker exit success, worker pid: " << worker_context->GetWorkerPid() << ", worker address: " << worker_context->GetWorkerAddress(); return SUCCESS; } Status Dispatcher::RegisterServable(const proto::RegisterRequest &request, proto::RegisterReply *) { WorkerRegSpec worker_spec; GrpcTensorHelper::ConvertProtoWorkerSpec(request, &worker_spec); auto create_notify_worker = [](const WorkerRegSpec &worker_spec) { std::shared_ptr notify_worker = std::make_shared(worker_spec.worker_address); return notify_worker; }; return RegisterServableCommon(worker_spec, create_notify_worker); } Status Dispatcher::NotifyWorkerExit(const proto::ExitRequest &request, proto::ExitReply *) { return UnregisterServableCommon(request.address()); } void Dispatcher::UnregisterWorkerContext(WorkerContext *worker_context) { MSI_EXCEPTION_IF_NULL(worker_context); std::unique_lock lock(servable_shared_lock_); auto worker_spec = worker_context->GetWorkerSpec(); auto &servable_spec = worker_spec.servable_spec; std::shared_ptr endpoint = nullptr; for (auto &item : servable_list_) { if (item->GetServableName() == servable_spec.servable_name && item->GetVersionNumber() == servable_spec.version_number) { endpoint = item; break; } } if (endpoint) { endpoint->UnregisterWorker(worker_context->GetWorkerAddress()); } } Status Dispatcher::NotifyWorkerNotAlive(WorkerContext *worker_context) { MSI_EXCEPTION_IF_NULL(worker_context); UnregisterWorkerContext(worker_context); worker_context->OnNotAlive(); return SUCCESS; } Status Dispatcher::NotifyWorkerNotAvailable(WorkerContext *worker_context) { MSI_EXCEPTION_IF_NULL(worker_context); UnregisterWorkerContext(worker_context); 
worker_context->OnNotAvailable(); return SUCCESS; } void Dispatcher::GetModelInfo(const proto::GetModelInfoRequest *request, proto::GetModelInfoReply *reply) { auto &servable_name = request->servable_name(); auto version_number = request->version_number(); for (auto &worker : worker_list_) { auto worker_spec = worker->GetWorkerSpec(); if (worker_spec.servable_spec.servable_name == servable_name && worker_spec.servable_spec.version_number == version_number && worker_spec.servable_spec.own_device) { reply->set_servable_name(servable_name); reply->set_version_number(version_number); GrpcTensorHelper::ConvertModelInfos(worker_spec.servable_spec.models, reply->mutable_model_infos()); return; } } auto status = INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has models declared by declare_model, but parameter 'device_ids'" << " of ServableStartConfig is not set in Serving startup script when the device target is not CPU"; auto error_msg = reply->mutable_error_msg(); error_msg->set_error_code(FAILED); error_msg->set_error_msg(status.StatusMessage()); } bool Dispatcher::OnlyModelStage(const std::string &servable_name) { for (auto &worker : worker_list_) { auto worker_spec = worker->GetWorkerSpec(); if (worker_spec.servable_spec.servable_name != servable_name) { continue; } for (auto &method : worker_spec.servable_spec.methods) { // cppcheck-suppress useStlAlgorithm if (!method.only_model_stage) { return false; } } return true; } return false; } void Dispatcher::Clear() { std::unique_lock lock(servable_shared_lock_); for (auto &endpoint : servable_list_) { endpoint->Clear(); } for (auto &worker : worker_list_) { worker->Clear(); } servable_list_.clear(); worker_list_.clear(); } Status Dispatcher::RegisterServableCommon(const WorkerRegSpec &worker_spec, CreateNotifyWorkerFunc func) { MSI_EXCEPTION_IF_NULL(func); std::unique_lock lock(servable_shared_lock_); std::shared_ptr worker_context = nullptr; for (auto &item : worker_list_) { if (item->GetWorkerPid() == worker_spec.worker_pid) { worker_context = item; break; } } bool ready = true; if (worker_context == nullptr) { worker_context = std::make_shared(); worker_context->UpdateWorkerPid(worker_spec.worker_pid); worker_list_.push_back(worker_context); ready = false; } worker_context->OnWorkerRegRequest(worker_spec, func(worker_spec)); if (ready) { auto status = RegisterWorkerContext(worker_context); if (status != SUCCESS) { MSI_LOG_ERROR << "Registered worker failed"; worker_context->OnStartError("Registered worker failed"); } } return SUCCESS; } Status Dispatcher::NotifyWorkerFailed(const proto::NotifyFailedRequest *request, proto::NotifyFailedReply *reply) { auto worker_pid = request->worker_pid(); auto error_msg = request->error_msg(); MSI_LOG_ERROR << "Worker notify failed, worker pid: " << worker_pid << ", error reported: <" << error_msg << ">"; std::unique_lock lock(servable_shared_lock_); std::shared_ptr worker_context = nullptr; for (auto &item : worker_list_) { if (item->GetWorkerPid() == worker_pid) { worker_context = item; break; } } if (worker_context == nullptr) { worker_context = std::make_shared(); worker_context->UpdateWorkerPid(worker_pid); worker_list_.push_back(worker_context); } worker_context->OnStartError(error_msg); return SUCCESS; } std::shared_ptr Dispatcher::InitWorkerContext(const ServableReprInfo &repr, uint64_t worker_pid) { std::unique_lock lock(servable_shared_lock_); std::shared_ptr worker_context = nullptr; for (auto &item : worker_list_) { if (item->GetWorkerPid() == worker_pid) { worker_context 
= item; break; } } bool ready = true; if (worker_context == nullptr) { worker_context = std::make_shared(); worker_context->UpdateWorkerPid(worker_pid); worker_list_.push_back(worker_context); ready = false; } worker_context->InitServableReprInfo(repr); if (ready) { auto status = RegisterWorkerContext(worker_context); if (status != SUCCESS) { MSI_LOG_ERROR << "Registered worker failed"; worker_context->OnStartError("Registered worker failed"); } } return worker_context; } Status Dispatcher::RegisterWorkerContext(std::shared_ptr worker_context) { auto worker_spec = worker_context->GetWorkerSpec(); auto &servable_spec = worker_spec.servable_spec; if (servable_spec.servable_name.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable name cannot be empty"; } if (servable_spec.version_number <= 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable name " << servable_spec.servable_name << " version number " << servable_spec.version_number << " cannot be 0"; } std::shared_ptr endpoint = nullptr; for (auto &item : servable_list_) { if (item->GetServableName() == servable_spec.servable_name && item->GetVersionNumber() == servable_spec.version_number) { endpoint = item; break; } } if (!endpoint) { endpoint = std::make_shared(worker_context->GetServableReprInfo()); servable_list_.push_back(endpoint); } endpoint->RegisterWorker(servable_spec, worker_context); worker_context->OnReady(); return SUCCESS; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/master/dispacther.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_MASTER_DISPACTHER_H #define MINDSPORE_SERVING_MASTER_DISPACTHER_H #include #include #include #include #include #include "proto/ms_worker.grpc.pb.h" #include "common/serving_common.h" #include "common/instance.h" #include "common/servable.h" #include "master/notify_worker/base_notify.h" #include "common/grpc_client.h" #include "master/worker_context.h" #include "master/servable_endpoint.h" namespace mindspore::serving { class Dispatcher { public: Dispatcher(); ~Dispatcher(); void DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &on_finish); Status RegisterServable(const proto::RegisterRequest &request, proto::RegisterReply *reply); Status NotifyWorkerExit(const proto::ExitRequest &request, proto::ExitReply *reply); Status NotifyWorkerFailed(const proto::NotifyFailedRequest *request, proto::NotifyFailedReply *reply); Status NotifyWorkerNotAlive(WorkerContext *worker_context); Status NotifyWorkerNotAvailable(WorkerContext *worker_context); void GetModelInfo(const proto::GetModelInfoRequest *request, proto::GetModelInfoReply *reply); void Clear(); std::shared_ptr InitWorkerContext(const ServableReprInfo &repr, uint64_t worker_pid); bool OnlyModelStage(const std::string &servable_name); private: std::vector> servable_list_; std::vector> worker_list_; std::shared_mutex servable_shared_lock_; std::atomic_uint32_t enqueued_requests_ = 0; Status JudgeInferNum(); std::shared_ptr GetWorkerEndpoint(const RequestSpec &request_spec) const; using CreateNotifyWorkerFunc = std::function(const WorkerRegSpec &worker_spec)>; Status RegisterServableCommon(const WorkerRegSpec &worker_spec, CreateNotifyWorkerFunc func); Status UnregisterServableCommon(const std::string &worker_address); Status DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &on_finish); Status RegisterWorkerContext(std::shared_ptr worker_context); void UnregisterWorkerContext(WorkerContext *worker_context); }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_MASTER_DISPACTHER_H ================================================ FILE: mindspore_serving/ccsrc/master/grpc/grpc_process.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "master/grpc/grpc_process.h" #include #include "master/dispacther.h" namespace mindspore { namespace serving { namespace { std::string GetProtorWorkerSpecRepr(const proto::WorkerRegSpec &worker_spec) { std::stringstream str; auto &servable_spec = worker_spec.servable_spec(); str << "{name:" << servable_spec.name() << ", version:" << servable_spec.version_number() << ", method:["; for (int k = 0; k < servable_spec.methods_size(); k++) { str << servable_spec.methods(k).name(); if (k + 1 < servable_spec.methods_size()) { str << ","; } } str << "]}"; return str.str(); } } // namespace void MSServiceImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, PredictOnFinish on_finish) { dispatcher_->DispatchAsync(*request, reply, on_finish); } grpc::Status MSMasterImpl::Register(const proto::RegisterRequest *request, proto::RegisterReply *reply) { MSI_EXCEPTION_IF_NULL(request); MSI_EXCEPTION_IF_NULL(reply); auto worker_sig = [request]() { std::stringstream str; str << "worker address: " << request->worker_spec().address() << ", servable: "; str << GetProtorWorkerSpecRepr(request->worker_spec()); return str.str(); }; Status status(FAILED); status = dispatcher_->RegisterServable(*request, reply); if (status != SUCCESS) { MSI_LOG_ERROR << "Register servable failed, " << worker_sig(); return grpc::Status::OK; } MSI_LOG(INFO) << "Register success: " << worker_sig(); return grpc::Status::OK; } grpc::Status MSMasterImpl::Exit(const proto::ExitRequest *request, proto::ExitReply *reply) { MSI_EXCEPTION_IF_NULL(request); MSI_EXCEPTION_IF_NULL(reply); auto worker_sig = [request]() { std::stringstream str; str << "worker address: " << request->address(); return str.str(); }; MSI_LOG(INFO) << "Worker Exit, " << worker_sig(); Status status = dispatcher_->NotifyWorkerExit(*request, reply); if (status != SUCCESS) { MSI_LOG_ERROR << "UnRegister servable failed, " << worker_sig(); return grpc::Status::OK; } return grpc::Status::OK; } grpc::Status MSMasterImpl::NotifyFailed(const proto::NotifyFailedRequest *request, proto::NotifyFailedReply *reply) { dispatcher_->NotifyWorkerFailed(request, reply); return grpc::Status::OK; } grpc::Status MSMasterImpl::GetModelInfo(const proto::GetModelInfoRequest *request, proto::GetModelInfoReply *reply) { dispatcher_->GetModelInfo(request, reply); return grpc::Status::OK; } void MSMasterImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, const PredictOnFinish &on_finish) { dispatcher_->DispatchAsync(*request, reply, on_finish); } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/master/grpc/grpc_process.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_MASTER_GRPC_PROCESS_H #define MINDSPORE_SERVING_MASTER_GRPC_PROCESS_H #include #include #include #include #include #include "common/serving_common.h" #include "common/heart_beat.h" #include "proto/ms_service.pb.h" #include "proto/ms_service.grpc.pb.h" #include "proto/ms_master.pb.h" #include "proto/ms_master.grpc.pb.h" #include "proto/ms_worker.pb.h" #include "proto/ms_worker.grpc.pb.h" #include "master/dispacther.h" namespace mindspore { namespace serving { // Service Implement class MSServiceImpl { public: explicit MSServiceImpl(std::shared_ptr dispatcher) : dispatcher_(dispatcher) {} ~MSServiceImpl() = default; void PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, PredictOnFinish on_finish); private: std::shared_ptr dispatcher_; }; // Service Implement class MSMasterImpl { public: explicit MSMasterImpl(std::shared_ptr dispatcher) : dispatcher_(dispatcher) {} ~MSMasterImpl() = default; grpc::Status Register(const proto::RegisterRequest *request, proto::RegisterReply *reply); grpc::Status Exit(const proto::ExitRequest *request, proto::ExitReply *reply); grpc::Status NotifyFailed(const proto::NotifyFailedRequest *request, proto::NotifyFailedReply *reply); grpc::Status GetModelInfo(const proto::GetModelInfoRequest *request, proto::GetModelInfoReply *reply); void PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, const PredictOnFinish &on_finish); private: std::shared_ptr dispatcher_; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_MASTER_GRPC_PROCESS_H ================================================ FILE: mindspore_serving/ccsrc/master/grpc/grpc_server.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "master/grpc/grpc_server.h" #include #include #include "common/grpc_async_server.h" namespace mindspore { namespace serving {} // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/master/grpc/grpc_server.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
================================================
FILE: mindspore_serving/ccsrc/master/grpc/grpc_server.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_MASTER_GRPC_SERVER_H
#define MINDSPORE_SERVING_MASTER_GRPC_SERVER_H

#include <memory>
#include <string>
#include <utility>
#include "common/serving_common.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"
#include "common/grpc_async_server.h"
#include "master/grpc/grpc_process.h"

namespace mindspore {
namespace serving {
template <class Derived>
class ServiceGrpcContext : public GrpcAsyncServiceContext<MSServiceImpl, proto::MSService::AsyncService, Derived> {
 public:
  ServiceGrpcContext(MSServiceImpl *service_impl, proto::MSService::AsyncService *async_service,
                     grpc::ServerCompletionQueue *cq)
      : GrpcAsyncServiceContext<MSServiceImpl, proto::MSService::AsyncService, Derived>(service_impl, async_service,
                                                                                        cq) {}
  virtual void StartEnqueueRequest() = 0;
  virtual void HandleRequest() = 0;
};

class ServicePredictContext : public ServiceGrpcContext<ServicePredictContext> {
 public:
  ServicePredictContext(MSServiceImpl *service_impl, proto::MSService::AsyncService *async_service,
                        grpc::ServerCompletionQueue *cq)
      : ServiceGrpcContext<ServicePredictContext>(service_impl, async_service, cq), responder_(&ctx_) {}
  ~ServicePredictContext() = default;

  void StartEnqueueRequest() override {
    async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this);
  }

  void HandleRequest() override {
    MSI_TIME_STAMP_START(RequestHandle)
    auto instance_size = request_.instances_size();
    PredictOnFinish on_finish = [this, time_start_RequestHandle, instance_size]() {
      responder_.Finish(response_, grpc::Status::OK, this);
      MSI_TIME_STAMP_END_EXTRA(RequestHandle, "Request count " + std::to_string(instance_size))
    };
    service_impl_->PredictAsync(&request_, &response_, on_finish);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_;
  proto::PredictRequest request_;
  proto::PredictReply response_;
};

class ServiceGrpcServer : public GrpcAsyncServer<proto::MSService::AsyncService> {
 public:
  explicit ServiceGrpcServer(std::shared_ptr<Dispatcher> dispatcher)
      : GrpcAsyncServer<proto::MSService::AsyncService>(), service_impl_(MSServiceImpl(dispatcher)) {}
  ~ServiceGrpcServer() {}

  void EnqueueRequests() override { ServicePredictContext::EnqueueRequest(&service_impl_, &svc_, cq_.get()); }

 protected:
  MSServiceImpl service_impl_;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_MASTER_GRPC_SERVER_H
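// Minimal standalone sketch of the CRTP pattern ServiceGrpcContext above
// relies on: a static EnqueueRequest on the base allocates a fresh Derived
// context and starts its registration. Names are stand-ins; the real
// GrpcAsyncServiceContext also threads a completion queue through.
#include <iostream>

template <class Derived>
class AsyncContextSketch {
 public:
  virtual ~AsyncContextSketch() = default;
  virtual void StartEnqueueRequest() = 0;
  virtual void HandleRequest() = 0;
  static void EnqueueRequest() {
    auto *context = new Derived();  // in the real server the event loop owns and frees this
    context->StartEnqueueRequest();
  }
};

class PredictContextSketch : public AsyncContextSketch<PredictContextSketch> {
 public:
  void StartEnqueueRequest() override { std::cout << "waiting for the next Predict RPC" << std::endl; }
  void HandleRequest() override { std::cout << "handling a Predict RPC" << std::endl; }
};

int main() {
  PredictContextSketch::EnqueueRequest();  // mirrors ServicePredictContext::EnqueueRequest(...)
  return 0;
}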
================================================
FILE: mindspore_serving/ccsrc/master/grpc/master_server.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_MASTER_MASTER_SERVER_H
#define MINDSPORE_SERVING_MASTER_MASTER_SERVER_H

#include <memory>
#include <string>
#include <utility>
#include "common/serving_common.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"
#include "common/grpc_async_server.h"
#include "master/grpc/grpc_process.h"

namespace mindspore {
namespace serving {
template <class Derived>
class MasterGrpcContext : public GrpcAsyncServiceContext<MSMasterImpl, proto::MSMaster::AsyncService, Derived> {
 public:
  MasterGrpcContext(MSMasterImpl *service_impl, proto::MSMaster::AsyncService *async_service,
                    grpc::ServerCompletionQueue *cq)
      : GrpcAsyncServiceContext<MSMasterImpl, proto::MSMaster::AsyncService, Derived>(service_impl, async_service,
                                                                                      cq) {}
  virtual void StartEnqueueRequest() = 0;
  virtual void HandleRequest() = 0;
};

class MasterRegisterContext : public MasterGrpcContext<MasterRegisterContext> {
 public:
  MasterRegisterContext(MSMasterImpl *service_impl, proto::MSMaster::AsyncService *async_service,
                        grpc::ServerCompletionQueue *cq)
      : MasterGrpcContext<MasterRegisterContext>(service_impl, async_service, cq), responder_(&ctx_) {}
  ~MasterRegisterContext() = default;

  void StartEnqueueRequest() override {
    async_service_->RequestRegister(&ctx_, &request_, &responder_, cq_, cq_, this);
  }

  void HandleRequest() override {
    grpc::Status status = service_impl_->Register(&request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::RegisterReply> responder_;
  proto::RegisterRequest request_;
  proto::RegisterReply response_;
};

class MasterExitContext : public MasterGrpcContext<MasterExitContext> {
 public:
  MasterExitContext(MSMasterImpl *service_impl, proto::MSMaster::AsyncService *async_service,
                    grpc::ServerCompletionQueue *cq)
      : MasterGrpcContext<MasterExitContext>(service_impl, async_service, cq), responder_(&ctx_) {}
  ~MasterExitContext() = default;

  void StartEnqueueRequest() override { async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); }

  void HandleRequest() override {
    grpc::Status status = service_impl_->Exit(&request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_;
  proto::ExitRequest request_;
  proto::ExitReply response_;
};

class MasterNotifyFailedContext : public MasterGrpcContext<MasterNotifyFailedContext> {
 public:
  MasterNotifyFailedContext(MSMasterImpl *service_impl, proto::MSMaster::AsyncService *async_service,
                            grpc::ServerCompletionQueue *cq)
      : MasterGrpcContext<MasterNotifyFailedContext>(service_impl, async_service, cq), responder_(&ctx_) {}
  ~MasterNotifyFailedContext() = default;

  void StartEnqueueRequest() override {
    async_service_->RequestNotifyFailed(&ctx_, &request_, &responder_, cq_, cq_, this);
  }

  void HandleRequest() override {
    grpc::Status status = service_impl_->NotifyFailed(&request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::NotifyFailedReply> responder_;
  proto::NotifyFailedRequest request_;
  proto::NotifyFailedReply response_;
};

class MasterGetModelInfoContext : public MasterGrpcContext<MasterGetModelInfoContext> {
 public:
  MasterGetModelInfoContext(MSMasterImpl *service_impl, proto::MSMaster::AsyncService *async_service,
                            grpc::ServerCompletionQueue *cq)
      : MasterGrpcContext<MasterGetModelInfoContext>(service_impl, async_service, cq), responder_(&ctx_) {}
  ~MasterGetModelInfoContext() = default;

  void StartEnqueueRequest() override {
    async_service_->RequestGetModelInfo(&ctx_, &request_, &responder_, cq_, cq_, this);
  }

  void HandleRequest() override {
    grpc::Status status = service_impl_->GetModelInfo(&request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::GetModelInfoReply> responder_;
  proto::GetModelInfoRequest request_;
  proto::GetModelInfoReply response_;
};
class MasterPredictContext : public MasterGrpcContext<MasterPredictContext> {
 public:
  MasterPredictContext(MSMasterImpl *service_impl, proto::MSMaster::AsyncService *async_service,
                       grpc::ServerCompletionQueue *cq)
      : MasterGrpcContext<MasterPredictContext>(service_impl, async_service, cq), responder_(&ctx_) {}
  ~MasterPredictContext() = default;

  void StartEnqueueRequest() override {
    async_service_->RequestCallModel(&ctx_, &request_, &responder_, cq_, cq_, this);
  }

  void HandleRequest() override {
    PredictOnFinish on_finish = [this]() { responder_.Finish(response_, grpc::Status::OK, this); };
    service_impl_->PredictAsync(&request_, &response_, on_finish);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_;
  proto::PredictRequest request_;
  proto::PredictReply response_;
};

class MasterGrpcServer : public GrpcAsyncServer<proto::MSMaster::AsyncService> {
 public:
  explicit MasterGrpcServer(std::shared_ptr<Dispatcher> dispatcher)
      : GrpcAsyncServer<proto::MSMaster::AsyncService>(), service_impl_(MSMasterImpl(dispatcher)) {}
  ~MasterGrpcServer() {}

  void EnqueueRequests() override {
    MasterRegisterContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    MasterExitContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    MasterNotifyFailedContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    MasterGetModelInfoContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    MasterPredictContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
  }

 protected:
  MSMasterImpl service_impl_;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_MASTER_MASTER_SERVER_H

================================================
FILE: mindspore_serving/ccsrc/master/master_context.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "master/master_context.h"

namespace mindspore::serving {
std::shared_ptr<MasterContext> MasterContext::Instance() {
  static std::shared_ptr<MasterContext> instance = nullptr;
  if (instance == nullptr) {
    instance = std::make_shared<MasterContext>();
  }
  return instance;
}

void MasterContext::SetMaxEnqueuedRequests(uint32_t max_enqueued_requests) {
  max_enqueued_requests_ = max_enqueued_requests;
}

uint32_t MasterContext::GetMaxEnqueuedRequests() const { return max_enqueued_requests_; }
}  // namespace mindspore::serving
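// Usage sketch for the MasterContext singleton above, with stand-in names.
// Note the sketch uses a function-local static, whose initialization C++11
// guarantees to be thread-safe; the original lazy null-check is equivalent
// only when first use is single-threaded.
#include <cstdint>
#include <iostream>
#include <memory>

class ContextSketch {
 public:
  static std::shared_ptr<ContextSketch> Instance() {
    static std::shared_ptr<ContextSketch> instance = std::make_shared<ContextSketch>();
    return instance;
  }
  void SetMaxEnqueuedRequests(uint32_t max) { max_enqueued_requests_ = max; }
  uint32_t GetMaxEnqueuedRequests() const { return max_enqueued_requests_; }

 private:
  uint32_t max_enqueued_requests_ = 10000;
};

int main() {
  ContextSketch::Instance()->SetMaxEnqueuedRequests(5000);
  std::cout << ContextSketch::Instance()->GetMaxEnqueuedRequests() << std::endl;  // prints 5000
  return 0;
}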
================================================
FILE: mindspore_serving/ccsrc/master/master_context.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_MASTER_CONTEXT_H
#define MINDSPORE_SERVING_MASTER_CONTEXT_H

#include <map>
#include <memory>
#include <mutex>
#include <string>
#include "common/serving_common.h"

namespace mindspore::serving {
class MS_API MasterContext {
 public:
  static std::shared_ptr<MasterContext> Instance();
  void SetMaxEnqueuedRequests(uint32_t max_enqueued_requests);
  uint32_t GetMaxEnqueuedRequests() const;

 private:
  uint32_t max_enqueued_requests_ = 10000;  // default 10000
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_MASTER_CONTEXT_H

================================================
FILE: mindspore_serving/ccsrc/master/model_thread.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "master/model_thread.h"
#include "common/proto_tensor.h"

namespace mindspore::serving {
ModelThread::ModelThread(const std::string &servable_name, const std::string &method_name, uint64_t version_number,
                         uint64_t batch_size, ServableMethodInfo method_info) {
  spec_.servable_name = servable_name;
  spec_.method_name = method_name;
  spec_.version_number = version_number;
  method_info_ = method_info;
  batch_size_ = batch_size;
}

void ModelThread::Clear() {
  std::unique_lock<std::mutex> lock(lock_);
  InnerClear();
}

void ModelThread::InnerClear() {
  for (auto &job_item : job_) {
    auto reply = job_item.second.reply;
    bool has_reply = false;
    bool has_error = false;
    proto::ErrorMsg detect_error;
    proto::ErrorMsg exit_error;
    RequestSpec request_spec;
    GrpcTensorHelper::GetRequestSpec(*job_item.second.request, &request_spec);
    auto status = INFER_STATUS(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", servable is not available";
    exit_error.set_error_code(status.StatusCode());
    exit_error.set_error_msg(status.StatusMessage());
    for (auto &task_item : job_item.second.task) {
      auto instance = reply->add_instances();
      auto error = reply->add_error_msg();
      if (task_item.error.error_code() != 0) {
        *error = task_item.error;
        if (!has_error) {
          has_error = true;
          detect_error = task_item.error;
        }
      } else if (task_item.output != nullptr) {
        *instance = *task_item.output;
        has_reply = true;
      } else {
        *error = exit_error;
      }
    }
    if (!has_error && !has_reply) {
      job_item.second.reply->clear_instances();
      job_item.second.reply->clear_error_msg();
      auto error_msg = job_item.second.reply->add_error_msg();
      *error_msg = exit_error;
    } else if (!has_reply) {
      job_item.second.reply->clear_instances();
      job_item.second.reply->clear_error_msg();
      auto error_msg = job_item.second.reply->add_error_msg();
      *error_msg = detect_error;
    }
    job_item.second.callback();
  }
  job_.clear();
  pid_process_.clear();
  task_wait_queue_ = std::queue<std::pair<uint64_t, uint64_t>>();
  worker_wait_map_.clear();
}

ModelThread::~ModelThread() { Clear(); }
Status ModelThread::AddWorker(uint64_t pid, const std::shared_ptr<WorkerContext> &notify) {
  {
    std::unique_lock<std::mutex> lock(lock_);
    auto it = pid_process_.find(pid);
    if (it != pid_process_.end()) {
      MSI_LOG(INFO) << "pid already exists: " << pid;
      return FAILED;
    }
    pid_process_.insert(std::make_pair(pid, notify));
    if (single_batch_dispatch_) {
      worker_wait_map_[pid] = static_cast<int64_t>(round_ * batch_size_);
    } else {
      worker_wait_map_[pid] = static_cast<int64_t>(round_);
    }
  }
  SendTasks();
  return SUCCESS;
}

Status ModelThread::DelWorker(uint64_t pid) {
  {
    std::unique_lock<std::mutex> lock(lock_);
    auto it = pid_process_.find(pid);
    if (it == pid_process_.end()) {
      MSI_LOG(INFO) << "pid does not exist: " << pid;
      return FAILED;
    }
    (void)pid_process_.erase(it);
    auto worker_it = worker_wait_map_.find(pid);
    if (worker_it == worker_wait_map_.end()) {
      MSI_LOG(INFO) << "pid does not exist in worker wait map: " << pid;
      return FAILED;
    }
    (void)worker_wait_map_.erase(worker_it);
    for (auto &job_item : job_) {
      auto job_id = job_item.first;
      auto &task_list = job_item.second.task;
      for (size_t i = 0; i < task_list.size(); ++i) {
        if (task_list[i].pid == pid) {
          auto task_id = i;
          task_wait_queue_.push(std::make_pair(job_id, task_id));
        }
      }
    }
    if (pid_process_.empty()) {
      InnerClear();
    }
  }
  SendTasks();
  return SUCCESS;
}

Status ModelThread::FindProcessQueue(uint64_t *pid) {
  int64_t max_free_slot = 0;
  uint64_t cur_pid = 0;
  for (auto &item : worker_wait_map_) {
    auto slot = item.second;
    if (slot <= 0 || slot < max_free_slot) {
      continue;
    }
    if (slot > max_free_slot || (cur_pid <= last_worker_pid_ && item.first > last_worker_pid_)) {
      max_free_slot = slot;
      cur_pid = item.first;
    }
  }
  if (cur_pid != 0) {
    worker_wait_map_[cur_pid]--;
    last_worker_pid_ = cur_pid;
    *pid = cur_pid;
    return SUCCESS;
  }
  return FAILED;
}

Status ModelThread::PushTasks(const proto::PredictRequest &request, proto::PredictReply *reply,
                              const PredictOnFinish &callback) {
  auto status = GrpcTensorHelper::CheckRequestInstances(request, method_info_.input_names);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Check request failed";
    return status;
  }
  std::unique_lock<std::mutex> lock(lock_);
  if (pid_process_.empty()) {
    RequestSpec request_spec;
    GrpcTensorHelper::GetRequestSpec(request, &request_spec);
    return INFER_STATUS_LOG_ERROR(SERVABLE_UNAVAILABLE)
           << "Request " << request_spec.Repr() << ", servable is not available";
  }
  auto it = job_.find(job_id_);
  if (it != job_.end()) {
    MSI_LOG(ERROR) << "job_id already exists: " << job_id_;
    return FAILED;
  }
  int instance_size = request.instances_size();
  Job job;
  job.wait_task_num = instance_size;
  job.callback = callback;
  job.request = &request;
  job.reply = reply;
  job.task.resize(instance_size);
  for (int i = 0; i < instance_size; i++) {
    Task &task = job.task[i];
    task.input = &request.instances(i);
    task.pid = 0;
    task_wait_queue_.push(std::make_pair(job_id_, i));
  }
  job_.insert(std::make_pair(job_id_, job));
  job_id_++;
  return SUCCESS;
}

Status ModelThread::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                                  const PredictOnFinish &callback) {
  auto status = PushTasks(request, reply, callback);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Push tasks into queue failed";
    return status;
  }
  SendTasks();
  return SUCCESS;
}
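// Standalone sketch of the batching step SendTasks performs below: pop up to
// batch_size (job_id, task_id) pairs off the wait queue for one worker.
#include <cstdint>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

using JobTaskId = std::pair<uint64_t, uint64_t>;

std::vector<JobTaskId> PopBatchSketch(std::queue<JobTaskId> *wait_queue, uint64_t batch_size) {
  std::vector<JobTaskId> batch;
  for (uint64_t i = 0; i < batch_size && !wait_queue->empty(); i++) {
    batch.push_back(wait_queue->front());
    wait_queue->pop();
  }
  return batch;
}

int main() {
  std::queue<JobTaskId> wait_queue;
  for (uint64_t task_id = 0; task_id < 5; task_id++) {
    wait_queue.push({0, task_id});  // one job with five task instances
  }
  auto batch = PopBatchSketch(&wait_queue, 4);
  std::cout << "dispatched " << batch.size() << ", pending " << wait_queue.size() << std::endl;  // 4, 1
  return 0;
}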
Status ModelThread::Combine(const std::vector<std::pair<uint64_t, uint64_t>> &ids, uint64_t pid,
                            proto::PredictRequest *msg) {
  std::vector<const proto::Instance *> inputs;  // ids->inputs
  for (auto it = begin(ids); it != end(ids); it++) {
    uint64_t job_id = it->first;
    uint64_t task_id = it->second;
    job_[job_id].task[task_id].pid = pid;
    inputs.push_back(job_[job_id].task[task_id].input);
  }
  return GrpcTensorHelper::CreatePredictRequestFromInstances(spec_, inputs, msg);
}

void ModelThread::SendTasks() {
  while (true) {
    std::shared_ptr<PredictContext> context;
    std::shared_ptr<WorkerContext> worker;
    {  // pop tasks
      std::unique_lock<std::mutex> lock(lock_);
      if (task_wait_queue_.empty()) {
        return;
      }
      uint64_t pid;
      auto status = FindProcessQueue(&pid);
      if (status != SUCCESS) {
        return;
      }
      context = std::make_shared<PredictContext>();
      std::vector<std::pair<uint64_t, uint64_t>> &inputs = context->inputs;
      if (single_batch_dispatch_) {
        inputs.push_back(task_wait_queue_.front());
        task_wait_queue_.pop();
      } else {
        for (uint64_t i = 0; i < batch_size_; i++) {
          if (task_wait_queue_.empty()) {
            break;
          }
          inputs.push_back(task_wait_queue_.front());
          task_wait_queue_.pop();
        }
      }
      context->pid = pid;
      Combine(inputs, pid, &context->request);  // inputs string->InstanceData, task pid status
      worker = pid_process_[pid];
    }
    // send request
    PredictOnFinish callback = [context, worker, this]() {
      bool worker_not_available = false;
      for (auto &error : context->reply.error_msg()) {
        if (error.error_code() == WORKER_UNAVAILABLE) {
          worker_not_available = true;
          break;
        }
      }
      if (worker_not_available) {
        worker->NotifyNotAvailable();
      } else {
        Commit(context);
      }
    };
    auto status = worker->DispatchAsync(context->request, &context->reply, callback);
    if (status != SUCCESS) {
      auto error_msg = context->reply.add_error_msg();
      error_msg->set_error_code(WORKER_UNAVAILABLE);
      error_msg->set_error_msg(status.StatusMessage());
      worker->NotifyNotAvailable();
    }
  }
}

void ModelThread::OnTasksFinished(const std::shared_ptr<PredictContext> &context) {
  std::unique_lock<std::mutex> lock(lock_);
  const auto pid = context->pid;
  const auto &inputs = context->inputs;
  if (pid_process_.find(pid) != pid_process_.end()) {
    worker_wait_map_[pid]++;
  }
  std::vector<proto::ErrorMsg> error;
  std::vector<const proto::Instance *> output;
  auto status = GrpcTensorHelper::CreateInstanceFromPredictReply(spec_, context->reply, &error, &output);
  if (status != SUCCESS) {
    status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR)
             << "Get reply failed, servable name: " << spec_.servable_name
             << ", method name: " << spec_.method_name << ", version number: " << spec_.version_number;
  }
  if (!output.empty() && output.size() != inputs.size()) {
    status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "The instance count " << output.size()
                                                  << " of reply is not equal to the count " << inputs.size()
                                                  << " of request";
  }
  if (status != SUCCESS) {
    output.clear();
    error.clear();
    proto::ErrorMsg error_msg;
    error_msg.set_error_code(status.StatusCode());
    error_msg.set_error_msg(status.StatusMessage());
    error.push_back(error_msg);
  }
  for (unsigned int i = 0; i < inputs.size(); i++) {
    uint64_t task_id = inputs[i].second;
    uint64_t job_id = inputs[i].first;
    auto iter2 = job_.find(job_id);
    if (iter2 == job_.end()) {
      MSI_LOG_ERROR << "job_id does not exist: " << job_id;
      continue;
    }
    auto &job_item = iter2->second;
    // collect result
    auto &task_item = job_item.task[task_id];
    task_item.pid = 0;
    if (i < output.size()) {
      task_item.output = output[i];
    }
    if (error.empty()) {
      task_item.error.set_error_code(0);
    } else if (error.size() == 1) {
      task_item.error = error[0];
    } else {
      task_item.error = error[i];
    }
    job_item.wait_task_num--;
    job_item.reply_context_list.push_back(context);
    if (job_item.wait_task_num == 0) {
      // reply job
      std::vector<const proto::Instance *> out;
      std::vector<proto::ErrorMsg> error_reply;
      for (auto &item : job_item.task) {
        out.push_back(item.output);
        error_reply.push_back(item.error);
      }
      GrpcTensorHelper::CreatePredictReplyFromInstances(*job_item.request, error_reply, out, job_item.reply);
      job_item.callback();
      (void)job_.erase(iter2);
    }
  }
}

void ModelThread::Commit(const std::shared_ptr<PredictContext> &context) {
  OnTasksFinished(context);
  SendTasks();
}
}  // namespace mindspore::serving
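// Standalone sketch of the selection rule in FindProcessQueue above: prefer
// the worker with the most free slots; on ties, prefer the first pid after
// the one used last, so equally loaded workers take turns.
#include <cstdint>
#include <iostream>
#include <map>

uint64_t PickWorkerSketch(const std::map<uint64_t, int64_t> &free_slots, uint64_t last_pid) {
  int64_t max_free_slot = 0;
  uint64_t cur_pid = 0;
  for (const auto &item : free_slots) {
    if (item.second <= 0 || item.second < max_free_slot) {
      continue;
    }
    if (item.second > max_free_slot || (cur_pid <= last_pid && item.first > last_pid)) {
      max_free_slot = item.second;
      cur_pid = item.first;
    }
  }
  return cur_pid;  // 0 means no worker has a free slot
}

int main() {
  std::map<uint64_t, int64_t> slots{{11, 2}, {22, 2}, {33, 1}};
  std::cout << PickWorkerSketch(slots, 11) << std::endl;  // prints 22: equal slots, next after pid 11
  return 0;
}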
================================================
FILE: mindspore_serving/ccsrc/master/model_thread.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_MASTER_MODEL_THREAD_H
#define MINDSPORE_SERVING_MASTER_MODEL_THREAD_H

#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "common/serving_common.h"
#include "common/instance.h"
#include "master/notify_worker/base_notify.h"
#include "proto/ms_service.pb.h"
#include "proto/ms_service.grpc.pb.h"
#include "master/worker_context.h"

namespace mindspore::serving {
struct Task {
  const proto::Instance *input = nullptr;
  const proto::Instance *output = nullptr;
  proto::ErrorMsg error;
  uint64_t pid = 0;  // 0: not executing or already executed; otherwise: pid of the executing worker
};

struct PredictContext {
  proto::PredictRequest request;
  proto::PredictReply reply;
  uint64_t pid;
  std::vector<std::pair<uint64_t, uint64_t>> inputs;
};

struct Job {
  std::vector<Task> task;
  uint64_t wait_task_num = 0;
  PredictOnFinish callback;
  const proto::PredictRequest *request = nullptr;
  proto::PredictReply *reply = nullptr;
  std::vector<std::shared_ptr<PredictContext>> reply_context_list;
};

class ModelThread {
 public:
  ModelThread(const std::string &servable_name, const std::string &method_name, uint64_t version_number,
              uint64_t batch_size, ServableMethodInfo method_info);
  ~ModelThread();
  Status DelWorker(uint64_t pid);
  Status AddWorker(uint64_t pid, const std::shared_ptr<WorkerContext> &notify);
  Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                       const PredictOnFinish &callback);

 private:
  std::map<uint64_t, std::shared_ptr<WorkerContext>> pid_process_;
  uint64_t last_worker_pid_ = 0;
  std::map<uint64_t, int64_t> worker_wait_map_;
  std::queue<std::pair<uint64_t, uint64_t>> task_wait_queue_;
  std::map<uint64_t, Job> job_;
  uint64_t job_id_ = 0;
  uint64_t round_ = 3;
  std::mutex lock_;
  RequestSpec spec_;
  ServableMethodInfo method_info_;
  uint64_t batch_size_;
  bool single_batch_dispatch_ = false;

  void Clear();
  void InnerClear();
  Status FindProcessQueue(uint64_t *pid);
  Status PushTasks(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &callback);
  Status Combine(const std::vector<std::pair<uint64_t, uint64_t>> &ids, uint64_t pid, proto::PredictRequest *msg);
  void OnTasksFinished(const std::shared_ptr<PredictContext> &context);
  void SendTasks();
  void Commit(const std::shared_ptr<PredictContext> &context);
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_MASTER_MODEL_THREAD_H
================================================
FILE: mindspore_serving/ccsrc/master/notify_worker/base_notify.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_MASTER_BASE_NOTIFY_H
#define MINDSPORE_SERVING_MASTER_BASE_NOTIFY_H

#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "common/serving_common.h"
#include "common/servable.h"
#include "proto/ms_service.pb.h"
#include "common/grpc_client.h"

namespace mindspore {
namespace serving {
class MS_API BaseNotifyWorker {
 public:
  BaseNotifyWorker() = default;
  virtual ~BaseNotifyWorker() = default;
  virtual Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                               const PredictOnFinish &on_finish) = 0;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_MASTER_BASE_NOTIFY_H

================================================
FILE: mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "master/notify_worker/grpc_notify.h"
#include <memory>
#include <string>
#include <utility>
#include "common/exit_handle.h"
#include "common/grpc_server.h"
#include "common/proto_tensor.h"

namespace mindspore {
namespace serving {
GrpcNotifyWorker::GrpcNotifyWorker(const std::string &worker_address) {
  worker_address_ = worker_address;
  std::shared_ptr<grpc::Channel> channel = GrpcServer::CreateChannel(worker_address);
  stub_ = proto::MSWorker::NewStub(channel);
}

GrpcNotifyWorker::~GrpcNotifyWorker() = default;

Status GrpcNotifyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                                       const PredictOnFinish &on_finish) {
  if (!stub_) {
    return INFER_STATUS_LOG_ERROR(WORKER_UNAVAILABLE)
           << "Predict failed, worker gRPC has not been inited or has already exited, worker address "
           << worker_address_;
  }
  if (!client_) {
    client_ = std::make_unique<MSWorkerClient>();  // async client type (from common/grpc_client.h) reconstructed
    client_->Start();
  }
  AsyncPredictCallback callback = [reply, on_finish](Status status) {
    GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply);
    on_finish();
  };
  client_->PredictAsync(request, reply, stub_.get(), callback, worker_address_);
  return SUCCESS;
}
}  // namespace serving
}  // namespace mindspore
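// Standalone sketch of the adapter in DispatchAsync above: whatever status
// the async client reports is folded into the reply before on_finish fires
// (stand-in types; the real code uses Status and proto::PredictReply).
#include <functional>
#include <iostream>
#include <string>

struct FakeReply { std::string error; };
using OnFinish = std::function<void()>;
using PredictCallback = std::function<void(const std::string &status)>;

PredictCallback MakeCallbackSketch(FakeReply *reply, const OnFinish &on_finish) {
  return [reply, on_finish](const std::string &status) {
    if (status != "SUCCESS") {
      reply->error = status;  // CreateReplyFromErrorMsg in the real code
    }
    on_finish();
  };
}

int main() {
  FakeReply reply;
  auto callback = MakeCallbackSketch(&reply, [&reply]() {
    std::cout << (reply.error.empty() ? "ok" : reply.error) << std::endl;
  });
  callback("WORKER_UNAVAILABLE");  // prints WORKER_UNAVAILABLE
  return 0;
}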
================================================
FILE: mindspore_serving/ccsrc/master/notify_worker/grpc_notify.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_MASTER_GRPC_NOTIFY_H
#define MINDSPORE_SERVING_MASTER_GRPC_NOTIFY_H

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "master/notify_worker/base_notify.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"

namespace mindspore {
namespace serving {
class MS_API GrpcNotifyWorker : public BaseNotifyWorker {
 public:
  explicit GrpcNotifyWorker(const std::string &worker_address);
  ~GrpcNotifyWorker() override;
  Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                       const PredictOnFinish &on_finish) override;

 private:
  std::string worker_address_;
  std::shared_ptr<proto::MSWorker::Stub> stub_ = nullptr;
  std::unique_ptr<MSWorkerClient> client_ = nullptr;  // member reconstructed from its use in grpc_notify.cc
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_MASTER_GRPC_NOTIFY_H

================================================
FILE: mindspore_serving/ccsrc/master/restful/http_handle.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "master/restful/http_handle.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "master/restful/http_process.h"
#include "master/server.h"

namespace mindspore {
namespace serving {
static std::vector<char> encode_table = {
  'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
  'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
  's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};

static std::vector<uint8_t> decode_table = {
  255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
  255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62,
  255, 255, 255, 63,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  255, 255, 255, 255, 255, 255, 255, 0,
  1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,
  23,  24,  25,  255, 255, 255, 255, 255, 255, 26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  255, 255, 255, 255, 255};

size_t Base64Encode(const uint8_t *input, size_t length, uint8_t *output) {
  if (length == 0) {
    return 0;
  }
  size_t i, j;
  for (i = 0, j = 0; i + 3 <= length; i += 3) {
    output[j++] = encode_table[input[i] >> 2];
    output[j++] = encode_table[((input[i] << 4) & 0x30) | (input[i + 1] >> 4)];
    output[j++] = encode_table[((input[i + 1] << 2) & 0x3c) | (input[i + 2] >> 6)];
    output[j++] = encode_table[input[i + 2] & 0x3f];
  }
  if (i < length) {
    uint32_t left_num = length - i;
    if (left_num == 1) {
      output[j++] = encode_table[input[i] >> 2];
      output[j++] = encode_table[(input[i] << 4) & 0x30];
      output[j++] = '=';
      output[j++] = '=';
    } else {
      output[j++] = encode_table[input[i] >> 2];
      output[j++] = encode_table[((input[i] << 4) & 0x30) | (input[i + 1] >> 4)];
      output[j++] = encode_table[(input[i + 1] << 2) & 0x3c];
      output[j++] = '=';
    }
  }
  return j;
}
size_t Base64Decode(const uint8_t *target, size_t target_length, uint8_t *origin) {
  if (target_length == 0 || target_length % 4 != 0) {
    return 0;
  }
  size_t i, j = 0;
  uint8_t value[4];
  for (i = 0; i < target_length; i += 4) {
    for (size_t k = 0; k < 4; k++) {
      value[k] = decode_table[target[i + k]];
    }
    // value[2], value[3]: may be '='
    if (value[0] >= 64 || value[1] >= 64) {
      MSI_LOG_EXCEPTION << "Decoded value should be less than 64";
    }
    origin[j++] = (value[0] << 2) | (value[1] >> 4);
    if (value[2] >= 64) {
      break;
    } else if (value[3] >= 64) {
      origin[j++] = (value[1] << 4) | (value[2] >> 2);
      break;
    } else {
      origin[j++] = (value[1] << 4) | (value[2] >> 2);
      origin[j++] = (value[2] << 6) | value[3];
    }
  }
  return j;
}

size_t GetB64TargetSize(size_t origin_len) {
  size_t target_size = 0;
  if (origin_len % 3 == 0) {
    target_size = (origin_len / 3) * 4;
  } else {
    target_size = (origin_len / 3 + 1) * 4;
  }
  return target_size;
}

size_t GetB64OriginSize(size_t target_len, size_t tail_size) {
  size_t origin_length = 0;
  if (target_len == 0 || target_len % 4 != 0) {
    return origin_length;
  }
  origin_length = 3 * (target_len / 4) - tail_size;
  return origin_length;
}

size_t GetTailEqualSize(const std::string &str) {
  size_t length = str.size();
  if (length % 4 != 0) {
    return UINT32_MAX;
  }
  size_t count = 0;
  if (length >= 1 && str[length - 1] == '=') {
    count++;
  }
  if (length >= 2 && str[length - 2] == '=') {
    count++;
  }
  return count;
}
}  // namespace serving
}  // namespace mindspore

================================================
FILE: mindspore_serving/ccsrc/master/restful/http_handle.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_MASTER_HTTP_HANDLE_H
#define MINDSPORE_SERVING_MASTER_HTTP_HANDLE_H

#include <string>
#include "nlohmann/json.hpp"
#include "common/serving_common.h"
#include "master/restful/restful_request.h"

using nlohmann::json;

namespace mindspore {
namespace serving {
size_t Base64Encode(const uint8_t *input, size_t length, uint8_t *output);
size_t Base64Decode(const uint8_t *target, size_t target_length, uint8_t *origin);
size_t GetB64TargetSize(size_t origin_len);
size_t GetB64OriginSize(size_t target_len, size_t tail_size);
size_t GetTailEqualSize(const std::string &str);
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_MASTER_HTTP_HANDLE_H
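// A minimal usage sketch of the base64 helpers declared above, assuming the
// declarations from http_handle.h are visible; the expected output was
// checked by hand against the encode table.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
#include "master/restful/http_handle.h"

int main() {
  const std::string input = "MS";  // 2 bytes -> 4 base64 chars with one '=' pad
  std::vector<uint8_t> out(mindspore::serving::GetB64TargetSize(input.size()));
  size_t n = mindspore::serving::Base64Encode(reinterpret_cast<const uint8_t *>(input.data()), input.size(),
                                              out.data());
  std::cout << std::string(out.begin(), out.begin() + n) << std::endl;  // prints "TVM="
  return 0;
}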
================================================
FILE: mindspore_serving/ccsrc/master/restful/http_process.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "master/restful/http_process.h"
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <typeinfo>
#include <vector>
#include "common/serving_common.h"
#include "master/restful/http_handle.h"
#include "common/float16.h"
#include "master/server.h"

using mindspore::serving::proto::Instance;
using mindspore::serving::proto::PredictReply;
using mindspore::serving::proto::PredictRequest;

namespace mindspore {
namespace serving {
const int BUF_MAX = 0x7FFFFFFF;

static const std::map<DataType, HTTP_DATA_TYPE> infer_type2_http_type{{DataType::kMSI_Int32, HTTP_DATA_INT},
                                                                      {DataType::kMSI_Float32, HTTP_DATA_FLOAT}};
static const std::map<HTTP_DATA_TYPE, DataType> http_type2_infer_type{{HTTP_DATA_INT, DataType::kMSI_Int32},
                                                                      {HTTP_DATA_FLOAT, DataType::kMSI_Float32},
                                                                      {HTTP_DATA_BOOL, DataType::kMSI_Bool},
                                                                      {HTTP_DATA_STR, DataType::kMSI_String},
                                                                      {HTTP_DATA_OBJ, DataType::kMSI_Bytes}};
static const std::map<std::string, DataType> str2_infer_type{
  {"int8", DataType::kMSI_Int8},     {"int16", DataType::kMSI_Int16},     {"int32", DataType::kMSI_Int32},
  {"int64", DataType::kMSI_Int64},   {"uint8", DataType::kMSI_Uint8},     {"uint16", DataType::kMSI_Uint16},
  {"uint32", DataType::kMSI_Uint32}, {"uint64", DataType::kMSI_Uint64},   {"fp16", DataType::kMSI_Float16},
  {"fp32", DataType::kMSI_Float32},  {"fp64", DataType::kMSI_Float64},    {"float16", DataType::kMSI_Float16},
  {"float32", DataType::kMSI_Float32}, {"float64", DataType::kMSI_Float64}, {"bool", DataType::kMSI_Bool},
  {"str", DataType::kMSI_String},    {"bytes", DataType::kMSI_Bytes}};

template <typename T>
bool RestfulService::IsString() {
  return typeid(T).hash_code() == typeid(std::string).hash_code();
}

std::string RestfulService::GetString(const uint8_t *ptr, size_t length) {
  std::string str;
  for (size_t i = 0; i < length; i++) {
    str += ptr[i];
  }
  return str;
}

Status RestfulService::CheckObjTypeMatchShape(DataType data_type, const std::vector<int64_t> &shape) {
  if (data_type == kMSI_String || data_type == kMSI_Bytes) {
    size_t elements_nums = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<size_t>());
    if (elements_nums != 1) {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
             << "json object, only support scalar when data type is string or bytes, please check 'type' or 'shape'";
    }
  }
  return SUCCESS;
}

RequestType RestfulService::GetReqType(const std::string &str) {
  auto it = std::find(request_type_list_.begin(), request_type_list_.end(), str);
  if (it == request_type_list_.end()) {
    return kInvalidType;
  }
  if (*it == kInstancesRequest) {
    return kInstanceType;
  }
  return kInvalidType;
}

std::string RestfulService::GetReqTypeStr(RequestType req_type) {
  switch (req_type) {
    case kInstanceType:
      return kInstancesRequest;
    default:
      break;
  }
  return "";
}

Status RestfulService::CheckObjType(const std::string &type) {
  Status status(SUCCESS);
  auto it = str2_infer_type.find(type);
  if (it == str2_infer_type.end()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, specified type:'" << type << "' is illegal";
  }
  return status;
}

DataType RestfulService::GetObjDataType(const json &js) {
  DataType type = kMSI_Unknown;
  if (!js.is_object()) {
    return type;
  }
  auto it1 = js.find(kType);
  if (it1 == js.end()) {
    type = kMSI_Bytes;
  } else {
    auto type_str = it1.value();
    auto it2 = str2_infer_type.find(type_str);
    if (it2 != str2_infer_type.end()) {
      type = it2->second;
    }
  }
  return type;
}

std::string RestfulService::GetStringByDataType(DataType type) {
  for (const auto &item : str2_infer_type) {
    // cppcheck-suppress useStlAlgorithm
    if (item.second == type) {
      return item.first;
    }
  }
  return "";
}
bool RestfulService::JsonMatchDataType(const json &js, DataType type) {
  bool flag = false;
  if (js.is_number_integer()) {
    if (type >= kMSI_Int8 && type <= kMSI_Uint64) {
      flag = true;
    }
  } else if (js.is_number_float()) {
    if (type >= kMSI_Float16 && type <= kMSI_Float64) {
      flag = true;
    }
  } else if (js.is_string()) {
    if (type == kMSI_String) {
      flag = true;
    }
  } else if (js.is_boolean()) {
    if (type == kMSI_Bool) {
      flag = true;
    }
  }
  return flag;
}

std::vector<int64_t> RestfulService::GetObjShape(const json &js) {
  std::vector<int64_t> shape;
  auto it = js.find(kShape);
  if (it != js.end()) {
    shape = GetSpecifiedShape(it.value());
  }
  return shape;
}

std::vector<int64_t> RestfulService::GetArrayShape(const json &json_array) {
  std::vector<int64_t> json_shape;
  const json *tmp_json = &json_array;
  while (tmp_json->is_array()) {
    if (tmp_json->empty()) {
      break;
    }
    (void)json_shape.emplace_back(tmp_json->size());
    tmp_json = &tmp_json->at(0);
  }
  return json_shape;
}

std::vector<int64_t> RestfulService::GetSpecifiedShape(const json &js) {
  std::vector<int64_t> shape;
  if (!js.is_array()) {
    return shape;
  }
  if (js.empty()) {
    return shape;
  }
  for (size_t i = 0; i < js.size(); i++) {
    auto &item = js.at(i);
    if (!item.is_number_unsigned()) {
      return {};
    } else {
      shape.push_back(item.get<int64_t>());
    }
  }
  return shape;
}

DataType RestfulService::GetArrayDataType(const json &json_array, HTTP_DATA_TYPE *type_format_ptr) {
  MSI_EXCEPTION_IF_NULL(type_format_ptr);
  auto &type_format = *type_format_ptr;
  DataType data_type = kMSI_Unknown;
  const json *tmp_json = &json_array;
  while (tmp_json->is_array()) {
    if (tmp_json->empty()) {
      return data_type;
    }
    tmp_json = &tmp_json->at(0);
  }
  if (tmp_json->is_number_integer()) {
    type_format = HTTP_DATA_INT;
    data_type = http_type2_infer_type.at(type_format);
  } else if (tmp_json->is_number_float()) {
    type_format = HTTP_DATA_FLOAT;
    data_type = http_type2_infer_type.at(type_format);
  } else if (tmp_json->is_boolean()) {
    type_format = HTTP_DATA_BOOL;
    data_type = http_type2_infer_type.at(type_format);
  } else if (tmp_json->is_object()) {
    type_format = HTTP_DATA_OBJ;
    data_type = GetObjDataType(*tmp_json);
  } else if (tmp_json->is_string()) {
    type_format = HTTP_DATA_STR;
    data_type = http_type2_infer_type.at(type_format);
  }
  return data_type;
}

Status RestfulService::CheckReqJsonValid(const json &js_msg) {
  int count = 0;
  for (auto &item : request_type_list_) {
    auto it = js_msg.find(item);
    if (it != js_msg.end()) {
      count++;
      auto request_type = GetReqType(item);
      if (request_type == kInvalidType) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "only support instances mode";
      }
      request_type_ = request_type;
    }
  }
  if (count != 1) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
           << "key 'instances' is expected to appear exactly once, but appears " << count << " times";
  }
  return SUCCESS;
}

Status RestfulService::GetInstancesType(const json &instances) {
  Status status{SUCCESS};
  // Eg:{"instances" : 1}
  if (!(instances.is_array() || instances.is_object())) {
    instances_type_ = kNokeyWay;
    return status;
  }
  // Eg:{"instances":{"A":1, "B":2}}
  if (instances.is_object()) {
    instances_type_ = kKeyWay;
    return status;
  }
  // array:
  if (instances.empty()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "instances value is array type, but no value";
  }
  auto first_instance = instances.at(0);
  if (first_instance.is_object()) {
    instances_type_ = kKeyWay;
  } else {
    instances_type_ = kNokeyWay;
  }
  return status;
}
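// Standalone sketch of the shape inference GetArrayShape above performs:
// walk nested arrays and record each level's size, so [[1,2,3],[4,5,6]]
// yields shape [2, 3].
#include <cstdint>
#include <iostream>
#include <vector>
#include "nlohmann/json.hpp"

std::vector<int64_t> ArrayShapeSketch(const nlohmann::json &js) {
  std::vector<int64_t> shape;
  const nlohmann::json *cur = &js;
  while (cur->is_array() && !cur->empty()) {
    shape.push_back(static_cast<int64_t>(cur->size()));
    cur = &cur->at(0);
  }
  return shape;
}

int main() {
  for (auto dim : ArrayShapeSketch(nlohmann::json::parse("[[1,2,3],[4,5,6]]"))) {
    std::cout << dim << " ";  // prints: 2 3
  }
  std::cout << std::endl;
  return 0;
}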
<< "json object, value is empty"; } // 1)required:b64 2)optional:type 3)optional:shape if (js.size() > 3) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, items size is more than 3, only support specified ['b64', 'type', 'shape']"; } int b64_count = 0; int shape_count = 0; int type_count = 0; for (auto item = js.begin(); item != js.end(); ++item) { const auto &key = item.key(); auto value = item.value(); if (key == kB64) { b64_count++; } else if (key == kType) { if (!value.is_string()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, key is 'type', value should be string type"; } auto status = CheckObjType(value); if (status != SUCCESS) { return status; } type_count++; } else if (key == kShape) { if (!value.is_array()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, key is 'shape', value should be array type"; } bool zero_dims_before = false; for (auto it = value.begin(); it != value.end(); ++it) { if (zero_dims_before) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, key is 'shape', invalid shape value " << value.dump(); } if (!(it->is_number_unsigned())) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, key is 'shape', array value should be unsigned integer"; } auto number = it->get(); if (number < 0) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, key is 'shape', number value should not be negative number, shape value: " << value.dump(); } if (number == 0) { zero_dims_before = true; } } shape_count++; } else { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, key is not ['b64', 'type', 'shape'], fail key:" << key; } } if (b64_count != 1) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, 'b64' should be specified only one time"; } if (type_count > 1) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, 'type' should be specified no more than one time"; } if (shape_count > 1) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, 'shape' should be specified no more than one time"; } return SUCCESS; } Status RestfulService::ParseItemScalar(const json &value, ProtoTensor *const pb_tensor) { Status status(SUCCESS); std::vector scalar_shape = {}; if (value.is_number_integer()) { DataType type = kMSI_Int32; pb_tensor->set_data_type(type); pb_tensor->set_shape(scalar_shape); pb_tensor->resize_data(pb_tensor->GetTypeSize(type)); status = GetScalarByType(type, value, 0, pb_tensor); } else if (value.is_number_float()) { DataType type = kMSI_Float32; pb_tensor->set_data_type(type); pb_tensor->set_shape(scalar_shape); pb_tensor->resize_data(pb_tensor->GetTypeSize(type)); status = GetScalarByType(type, value, 0, pb_tensor); } else if (value.is_boolean()) { DataType type = kMSI_Bool; pb_tensor->set_data_type(type); pb_tensor->set_shape(scalar_shape); pb_tensor->resize_data(pb_tensor->GetTypeSize(type)); status = GetScalarByType(type, value, 0, pb_tensor); } else if (value.is_string()) { DataType type = kMSI_String; pb_tensor->set_data_type(type); pb_tensor->set_shape(scalar_shape); status = GetScalarByType(type, value, 0, pb_tensor); } else if (value.is_null()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json value is null, it is not supported"; } else if (value.is_discarded()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json value is discarded type, it is not supported"; } else { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json value type is unregistered"; } return status; } Status RestfulService::ParseItemObject(const json &value, 
Status RestfulService::ParseItemObject(const json &value, ProtoTensor *const pb_tensor) {
  auto status = CheckObj(value);
  if (status != SUCCESS) {
    return status;
  }
  DataType type = GetObjDataType(value);
  if (type == kMSI_Unknown) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, type is unknown";
  }
  std::vector<int64_t> shape = GetObjShape(value);
  bool is_tensor = false;
  if (type != kMSI_String && type != kMSI_Bytes) {
    is_tensor = true;
  }
  if (is_tensor) {
    size_t shape_size = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<size_t>());
    size_t type_size = pb_tensor->GetTypeSize(type);
    pb_tensor->resize_data(shape_size * type_size);
  }
  status = CheckObjTypeMatchShape(type, shape);
  if (status != SUCCESS) {
    return status;
  }
  pb_tensor->set_data_type(type);
  pb_tensor->set_shape(shape);
  status = GetScalarByType(serving::kMSI_Bytes, value[kB64], 0, pb_tensor);
  return status;
}

Status RestfulService::ParseItemArray(const json &value, ProtoTensor *const pb_tensor) {
  HTTP_DATA_TYPE type_format = HTTP_DATA_NONE;
  auto shape = GetArrayShape(value);
  if (shape.empty()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, shape is empty";
  }
  DataType data_type = GetArrayDataType(value, &type_format);
  if (data_type == kMSI_Unknown) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, data type is unknown";
  }
  bool is_tensor = false;
  if (data_type != kMSI_String && data_type != kMSI_Bytes) {
    is_tensor = true;
  }
  // instances mode:only support one item
  if (request_type_ == kInstanceType) {
    if (!is_tensor) {
      size_t elements_nums = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<size_t>());
      if (elements_nums != 1) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, string or bytes type only support one item";
      }
    }
  }
  // set real data type
  pb_tensor->set_data_type(data_type);
  pb_tensor->set_shape(shape);
  if (is_tensor) {
    size_t shape_size = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<size_t>());
    size_t type_size = pb_tensor->GetTypeSize(data_type);
    pb_tensor->resize_data(shape_size * type_size);
  }
  if (type_format == HTTP_DATA_OBJ) {
    if (data_type != kMSI_Bytes && data_type != kMSI_String) {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
             << "json array, item is object type, object only support string or bytes type";
    }
  }
  return RecursiveGetArray(value, 0, 0, type_format, pb_tensor);
}
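// The tensor buffer sizing rule ParseItemArray above uses before copying
// data: the element count is the product of the dims, and the byte size is
// that count times the element size.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> shape{2, 3};
  size_t elements = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<size_t>());
  std::cout << elements * sizeof(int32_t) << " bytes" << std::endl;  // 24 bytes for int32[2,3]
  return 0;
}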
// 1. parse request common func
Status RestfulService::ParseItem(const json &value, ProtoTensor *const pb_tensor) {
  if (value.is_object()) {
    return ParseItemObject(value, pb_tensor);
  } else if (value.is_array()) {
    return ParseItemArray(value, pb_tensor);
  } else {
    return ParseItemScalar(value, pb_tensor);
  }
}

Status RestfulService::RecursiveGetArray(const json &json_data, size_t depth, size_t data_index,
                                         HTTP_DATA_TYPE type_format, ProtoTensor *const request_tensor) {
  Status status(SUCCESS);
  std::vector<int64_t> required_shape = request_tensor->shape();
  if (depth >= required_shape.size()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
           << "invalid json array: current depth " << depth << " is more than shape dims " << required_shape.size();
  }
  if (!json_data.is_array()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "invalid json array: json type is not array";
  }
  if (json_data.size() != static_cast<size_t>(required_shape[depth])) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "invalid json array: json size is " << json_data.size()
                                                  << ", the dim " << depth << " expected to be "
                                                  << required_shape[depth];
  }
  if (depth + 1 < required_shape.size()) {
    size_t sub_element_cnt =
      std::accumulate(required_shape.begin() + depth + 1, required_shape.end(), 1LL, std::multiplies<size_t>());
    for (size_t k = 0; k < json_data.size(); k++) {
      status = RecursiveGetArray(json_data[k], depth + 1, data_index + sub_element_cnt * k, type_format,
                                 request_tensor);
      if (status != SUCCESS) {
        return status;
      }
    }
  } else {
    status = GetArrayData(json_data, data_index, type_format, request_tensor);
    if (status != SUCCESS) {
      return status;
    }
  }
  return status;
}

Status RestfulService::GetArrayData(const json &js, size_t data_index, HTTP_DATA_TYPE type,
                                    ProtoTensor *const request_tensor) {
  Status status(SUCCESS);
  size_t element_nums = js.size();
  if (type != HTTP_DATA_OBJ) {
    for (size_t k = 0; k < element_nums; k++) {
      auto &json_data = js[k];
      if (!(json_data.is_number() || json_data.is_boolean() || json_data.is_string())) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, data should be number, bool, string or bytes";
      }
      auto flag = JsonMatchDataType(json_data, request_tensor->data_type());
      if (!flag) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, elements type is not equal";
      }
      status = GetScalarByType(request_tensor->data_type(), json_data, data_index + k, request_tensor);
      if (status != SUCCESS) {
        return status;
      }
    }
  } else {
    for (size_t k = 0; k < element_nums; k++) {
      auto &json_data = js[k];
      auto value_type = GetObjDataType(json_data);
      // Array:object only support string or bytes
      if (value_type != kMSI_String && value_type != kMSI_Bytes) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, object type only support string or bytes type";
      }
      if (value_type != request_tensor->data_type()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, elements type is not equal";
      }
      status = GetScalarByType(value_type, json_data[kB64], data_index + k, request_tensor);
      if (status != SUCCESS) {
        return status;
      }
    }
  }
  return status;
}
Status RestfulService::GetScalarByType(DataType type, const json &js, size_t index,
                                       ProtoTensor *const request_tensor) {
  Status status(SUCCESS);
  if (type == kMSI_Unknown) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "data type is unknown";
  }
  switch (type) {
    case kMSI_Bool:
      status = GetScalarData<bool>(js, index, false, request_tensor);
      break;
    case kMSI_Int8:
      status = GetScalarData<int8_t>(js, index, false, request_tensor);
      break;
    case kMSI_Int16:
      status = GetScalarData<int16_t>(js, index, false, request_tensor);
      break;
    case kMSI_Int32:
      status = GetScalarData<int32_t>(js, index, false, request_tensor);
      break;
    case kMSI_Int64:
      status = GetScalarData<int64_t>(js, index, false, request_tensor);
      break;
    case kMSI_Uint8:
      status = GetScalarData<uint8_t>(js, index, false, request_tensor);
      break;
    case kMSI_Uint16:
      status = GetScalarData<uint16_t>(js, index, false, request_tensor);
      break;
    case kMSI_Uint32:
      status = GetScalarData<uint32_t>(js, index, false, request_tensor);
      break;
    case kMSI_Uint64:
      status = GetScalarData<uint64_t>(js, index, false, request_tensor);
      break;
    case kMSI_Float16:
      status = GetScalarData<float16>(js, index, false, request_tensor);
      break;
    case kMSI_Float32:
      status = GetScalarData<float>(js, index, false, request_tensor);
      break;
    case kMSI_Float64:
      status = GetScalarData<double>(js, index, false, request_tensor);
      break;
    case kMSI_String:
      status = GetScalarData<std::string>(js, index, false, request_tensor);
      break;
    case kMSI_Bytes:
      status = GetScalarData<std::string>(js, index, true, request_tensor);
      break;
    default:
      auto type_str = GetStringByDataType(type);
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "data type:" << type_str << " is not supported";
  }
  return status;
}

template <typename T>
Status RestfulService::GetScalarData(const json &js, size_t index, bool is_bytes, ProtoTensor *const request_tensor) {
  Status status(SUCCESS);
  if (IsString<T>()) {
    // 1.string
    if (!js.is_string()) {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
             << "get scalar data failed, type is string, but json is not string type";
    }
    auto value = js.get<std::string>();
    if (is_bytes) {
      DataType real_type = request_tensor->data_type();
      auto tail_equal_size = GetTailEqualSize(value);
      if (tail_equal_size == UINT32_MAX) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "'" << value << "' is illegal b64 encode string";
      }
      auto origin_size = GetB64OriginSize(value.length(), tail_equal_size);
      std::vector<uint8_t> buffer(origin_size, 0);
      auto target_size = Base64Decode(reinterpret_cast<const uint8_t *>(value.data()), value.length(), buffer.data());
      if (target_size != origin_size) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "decode base64 failed, size is not matched.";
      }
      if (real_type == kMSI_Bytes || real_type == kMSI_String) {
        request_tensor->add_bytes_data(buffer.data(), origin_size);
      } else {
        auto type_size = request_tensor->GetTypeSize(real_type);
        auto element_cnt = request_tensor->element_cnt();
        if (origin_size != type_size * element_cnt) {
          return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
                 << "size is not matched, decode base64 size:" << origin_size
                 << "; Given info: type:" << GetStringByDataType(real_type) << "; type size:" << type_size
                 << "; element nums:" << element_cnt;
        }
        if (origin_size > 0) {
          auto data = reinterpret_cast<uint8_t *>(request_tensor->mutable_data()) + index;
          (void)memcpy_s(data, origin_size, buffer.data(), buffer.size());
        }
      }
    } else {
      request_tensor->add_bytes_data(reinterpret_cast<const uint8_t *>(value.data()), value.length());
    }
  } else {
    DataType data_type = request_tensor->data_type();
    auto flag = JsonMatchDataType(js, data_type);
    if (!flag) {
      auto type_str = GetStringByDataType(data_type);
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
             << "data type and json type is not matched, data type is:" << type_str;
    }
    // 2.number
    if ((js.is_number() || js.is_boolean())) {
      // 1)common number
      auto data = reinterpret_cast<T *>(request_tensor->mutable_data()) + index;
      *data = js.get<T>();
    }
  }
  return status;
}

// 2.main
void RestfulService::RunRestful(const std::shared_ptr<RestfulRequest> &restful_request) {
  auto restful_service = std::make_shared<RestfulService>();
  restful_service->RunRestfulInner(restful_request, restful_service);
}
void RestfulService::RunRestfulInner(const std::shared_ptr<RestfulRequest> &restful_request,
                                     const std::shared_ptr<RestfulService> &restful_service) {
  MSI_TIME_STAMP_START(RunRestful)
  auto status = ParseRequest(restful_request, &request_);
  if (status != SUCCESS) {
    std::string msg = "Parse request failed, " + status.StatusMessage();
    restful_request->ErrorMessage(Status(status.StatusCode(), msg));
    return;
  }
  auto callback = [restful_service, restful_request, time_start_RunRestful]() {
    nlohmann::json predict_json;
    Status status;
    try {
      status = restful_service->ParseReply(restful_service->reply_, &predict_json);
    } catch (std::exception &e) {
      MSI_LOG_ERROR << "Failed to construct the response: " << e.what();
      restful_request->ErrorMessage(Status(status.StatusCode(), "Failed to construct the response"));
      return;
    }
    if (status != SUCCESS) {
      std::string msg = "Failed to construct the response: " + status.StatusMessage();
      restful_request->ErrorMessage(Status(status.StatusCode(), msg));
    } else {
      restful_request->RestfulReplay(predict_json.dump());
    }
    MSI_TIME_STAMP_END(RunRestful)
  };
  auto dispatcher = Server::Instance().GetDispatcher();
  dispatcher->DispatchAsync(request_, &reply_, callback);
}

// 3.parse request
Status RestfulService::ParseRequest(const std::shared_ptr<RestfulRequest> &restful_request,
                                    PredictRequest *const request) {
  Status status(SUCCESS);
  // 1. parse common msg
  status = ParseReqCommonMsg(restful_request, request);
  if (status != SUCCESS) {
    return status;
  }
  // 2. parse json
  auto request_ptr = restful_request->decompose_event_request();
  auto &js_msg = request_ptr->request_message_;
  status = CheckReqJsonValid(js_msg);
  if (status != SUCCESS) {
    return status;
  }
  switch (request_type_) {
    case kInstanceType:
      status = ParseInstancesMsg(js_msg, request);
      break;
    default:
      return INFER_STATUS_LOG_ERROR(FAILED) << "restful request only support instances mode";
  }
  return status;
}

Status RestfulService::ParseReqCommonMsg(const std::shared_ptr<RestfulRequest> &restful_request,
                                         PredictRequest *const request) {
  Status status(SUCCESS);
  auto request_ptr = restful_request->decompose_event_request();
  if (request_ptr == nullptr) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Decompose event request is nullptr";
  }
  request->mutable_servable_spec()->set_name(request_ptr->model_name_);
  request->mutable_servable_spec()->set_version_number(request_ptr->version_);
  request->mutable_servable_spec()->set_method_name(request_ptr->service_method_);
  return status;
}

Status RestfulService::ParseInstancesMsg(const json &js_msg, PredictRequest *const request) {
  Status status = SUCCESS;
  auto type = GetReqTypeStr(request_type_);
  auto instances = js_msg.find(type);
  if (instances == js_msg.end()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "instances request json should have instances key word";
  }
  // get instances way:{key, value} or {value}
  status = GetInstancesType(*instances);
  if (status != SUCCESS) {
    return status;
  }
  switch (instances_type_) {
    case kKeyWay: {
      status = ParseKeyInstances(*instances, request);
      break;
    }
    case kNokeyWay: {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "instances no key mode is not supported";
    }
    case kInvalidWay: {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "invalid request type";
    }
  }
  return status;
}
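// Sketch of the key-value instances format that ParseKeyInstances below
// expects: one JSON object per instance, each key naming a servable input
// ("x1"/"x2" are illustrative names).
#include <iostream>
#include "nlohmann/json.hpp"

int main() {
  nlohmann::json req;
  req["instances"] = nlohmann::json::array();
  req["instances"].push_back({{"x1", 1}, {"x2", 2}});
  req["instances"].push_back({{"x1", 3}, {"x2", 4}});
  std::cout << req.dump() << std::endl;  // {"instances":[{"x1":1,"x2":2},{"x1":3,"x2":4}]}
  return 0;
}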
Status RestfulService::ParseKeyInstances(const json &instances, PredictRequest *const request) {
  Status status(SUCCESS);
  if (instances.is_object()) {
    // one instance:{"instances":{"A":1, "B": 2}}
    if (instances.empty()) {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json object, value is empty";
    }
    status = PaserKeyOneInstance(instances, request);
    if (status != SUCCESS) {
      MSI_LOG_ERROR << "instances:parse one instance failed";
      return status;
    }
    instances_nums_ = 1;
  } else {
    // multi instance:{"instances":[{}, {}]}
    for (size_t i = 0; i < instances.size(); i++) {
      auto &instance = instances.at(i);
      if (!instance.is_object()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, instance is not object type";
      }
      if (instance.empty()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "json array, instance is object type, but no value";
      }
      status = PaserKeyOneInstance(instance, request);
      if (status != SUCCESS) {
        return status;
      }
    }
    instances_nums_ = instances.size();
  }
  return status;
}

// instance_msg:one instance, type is object
Status RestfulService::PaserKeyOneInstance(const json &instance_msg, PredictRequest *const request) {
  Status status(SUCCESS);
  auto instance = request->add_instances();
  for (auto it = instance_msg.begin(); it != instance_msg.end(); ++it) {
    auto key = it.key();
    if (key.empty()) {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "string key is empty";
    }
    auto value = it.value();
    auto &map_item = *(instance->mutable_items());
    proto::Tensor &tensor = map_item[key];
    ProtoTensor pb_tensor(&tensor);
    status = ParseItem(value, &pb_tensor);
    if (status != SUCCESS) {
      return status;
    }
  }
  return status;
}

// 4.parse reply common func
Status RestfulService::ParseReplyDetail(const proto::Tensor &tensor, json *const js) {
  Status status(SUCCESS);
  const ProtoTensor pb_tensor(const_cast<proto::Tensor *>(&tensor));
  auto shape = pb_tensor.shape();
  if (shape.empty()) {
    status = ParseScalar(pb_tensor, 0, js);
    if (status != SUCCESS) {
      return status;
    }
  } else {
    status = CheckReply(pb_tensor);
    if (status != SUCCESS) {
      return status;
    }
    status = RecursiveParseArray(pb_tensor, 0, 0, js);
    if (status != SUCCESS) {
      return status;
    }
  }
  return status;
}

Status RestfulService::ParseScalar(const ProtoTensor &pb_tensor, size_t index, json *const js) {
  Status status(SUCCESS);
  DataType data_type = pb_tensor.data_type();
  if (data_type == kMSI_Unknown) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Data type is unknown";
  }
  switch (data_type) {
    case kMSI_Bool:
      status = ParseScalarData<bool>(pb_tensor, false, index, js);
      break;
    case kMSI_Int8:
      status = ParseScalarData<int8_t>(pb_tensor, false, index, js);
      break;
    case kMSI_Int16:
      status = ParseScalarData<int16_t>(pb_tensor, false, index, js);
      break;
    case kMSI_Int32:
      status = ParseScalarData<int32_t>(pb_tensor, false, index, js);
      break;
    case kMSI_Int64:
      status = ParseScalarData<int64_t>(pb_tensor, false, index, js);
      break;
    case kMSI_Uint8:
      status = ParseScalarData<uint8_t>(pb_tensor, false, index, js);
      break;
    case kMSI_Uint16:
      status = ParseScalarData<uint16_t>(pb_tensor, false, index, js);
      break;
    case kMSI_Uint32:
      status = ParseScalarData<uint32_t>(pb_tensor, false, index, js);
      break;
    case kMSI_Uint64:
      status = ParseScalarData<uint64_t>(pb_tensor, false, index, js);
      break;
    case kMSI_Float16: {
      const float16 *data = reinterpret_cast<const float16 *>(pb_tensor.data()) + index;
      float value = half_to_float(*data);
      *js = value;
      break;
    }
    case kMSI_Float32:
      status = ParseScalarData<float>(pb_tensor, false, index, js);
      break;
    case kMSI_Float64:
      status = ParseScalarData<double>(pb_tensor, false, index, js);
      break;
    case kMSI_String:
      status = ParseScalarData<std::string>(pb_tensor, false, index, js);
      break;
    case kMSI_Bytes:
      status = ParseScalarData<std::string>(pb_tensor, true, index, js);
      break;
    default:
      status = INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "reply data type is not supported";
      break;
  }
  return status;
}

template <typename T>
Status RestfulService::ParseScalarData(const ProtoTensor &pb_tensor, bool is_bytes, size_t index, json *const js) {
  Status status(SUCCESS);
  if (!IsString<T>()) {
    const T *data = reinterpret_cast<const T *>(pb_tensor.data()) + index;
    T value = *data;
    *js = value;
  } else if (IsString<T>()) {
    if (!is_bytes) {
      auto str_nums = pb_tensor.bytes_data_size();
      if (str_nums == 0) {
        return INFER_STATUS_LOG_ERROR(FAILED) << "reply string, size is 0";
      }
      if (index >= str_nums) {
        return INFER_STATUS_LOG_ERROR(FAILED) << "reply string, index:" << index << " is more than size:" << str_nums;
      }
      std::string value;
      size_t length;
      const uint8_t *ptr = nullptr;
      pb_tensor.get_bytes_data(index, &ptr, &length);
      value.resize(length);
      (void)memcpy_s(value.data(), length, ptr, length);
      *js = value;
    } else {
      auto str_nums = pb_tensor.bytes_data_size();
      if (str_nums == 0) {
        return INFER_STATUS_LOG_ERROR(FAILED) << "reply bytes, size is 0";
      }
      if (index >= str_nums) {
        return INFER_STATUS_LOG_ERROR(FAILED) << "reply bytes, index:" << index << " is more than size:" << str_nums;
      }
      std::string value;
      size_t length;
      const uint8_t *ptr = nullptr;
      pb_tensor.get_bytes_data(index, &ptr, &length);
      value.resize(length);
      (void)memcpy_s(value.data(), length, ptr, length);
      auto target_size = GetB64TargetSize(length);
      std::vector<uint8_t> buffer(target_size, 0);
      auto size = Base64Encode(reinterpret_cast<const uint8_t *>(value.data()), value.length(), buffer.data());
      if (size != target_size) {
        return INFER_STATUS_LOG_ERROR(FAILED)
               << "reply bytes, size is not matched, expected size:" << target_size << ", encode size:" << size;
      }
      std::string str = GetString(buffer.data(), buffer.size());
      (*js)[kB64] = str;
    }
  }
  return status;
}

Status RestfulService::RecursiveParseArray(const ProtoTensor &pb_tensor, size_t depth, size_t pos,
                                           json *const out_json) {
  Status status(SUCCESS);
  std::vector<int64_t> required_shape = pb_tensor.shape();
  if (depth >= required_shape.size()) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "result shape dims is larger than result shape size " << required_shape.size();
  }
  if (depth == required_shape.size() - 1) {
    if (required_shape[depth] == 0) {
      // make empty array
      out_json->push_back(json());
      out_json->clear();
    }
    for (int i = 0; i < required_shape[depth]; i++) {
      out_json->push_back(json());
      json &scalar_json = out_json->back();
      status = ParseScalar(pb_tensor, pos + i, &scalar_json);
      if (status != SUCCESS) {
        return status;
      }
    }
  } else {
    for (int i = 0; i < required_shape[depth]; i++) {
      // array:
      out_json->push_back(json());
      json &tensor_json = out_json->back();
      size_t sub_element_cnt = std::accumulate(required_shape.begin() + depth + 1, required_shape.end(), 1LL,
                                               std::multiplies<size_t>());
      status = RecursiveParseArray(pb_tensor, depth + 1, i * sub_element_cnt + pos, &tensor_json);
      if (status != SUCCESS) {
        return status;
      }
    }
  }
  return status;
}

Status RestfulService::CheckReply(const ProtoTensor &pb_tensor) {
  Status status(SUCCESS);
  DataType data_type = pb_tensor.data_type();
  if (data_type == kMSI_Unknown) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "reply data type is unknown";
  }
  if (data_type == kMSI_String || data_type == kMSI_Bytes) {
    auto shape = pb_tensor.shape();
    if (shape.size() != 1) {
      return INFER_STATUS_LOG_ERROR(FAILED)
             << "reply string or bytes, shape should be 1, given shape size:" << shape.size();
    }
  }
  return status;
}

// 5.Parse reply
Status RestfulService::ParseReply(const PredictReply &reply, json *const out_json) {
  Status status(SUCCESS);
  switch (request_type_) {
    case kInstanceType:
      status = ParseInstancesReply(reply, out_json);
      break;
    default:
      return INFER_STATUS_LOG_ERROR(FAILED) << "restful request only support instance mode";
  }
  return status;
}
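// Illustration only, not part of http_process.cc: a minimal standalone sketch
// of the index arithmetic in RecursiveParseArray above. At depth d, child i
// covers a block of prod(shape[d+1:]) elements of the flat row-major buffer.
// ToNested is an illustrative stand-in and assumes a non-empty shape.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>
#include <nlohmann/json.hpp>

nlohmann::json ToNested(const std::vector<float> &flat, const std::vector<int64_t> &shape,
                        size_t depth = 0, size_t pos = 0) {
  nlohmann::json out = nlohmann::json::array();
  if (depth == shape.size() - 1) {
    for (int64_t i = 0; i < shape[depth]; ++i) out.push_back(flat[pos + i]);
    return out;
  }
  size_t sub = std::accumulate(shape.begin() + depth + 1, shape.end(), size_t(1), std::multiplies<size_t>());
  for (int64_t i = 0; i < shape[depth]; ++i) out.push_back(ToNested(flat, shape, depth + 1, pos + i * sub));
  return out;
}

int main() {
  // shape {2, 3}: elements 0..2 form row 0, elements 3..5 form row 1
  std::vector<float> flat = {1, 2, 3, 4, 5, 6};
  std::cout << ToNested(flat, {2, 3}).dump() << std::endl;  // [[1.0,2.0,3.0],[4.0,5.0,6.0]]
  return 0;
}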
Status RestfulService::ParseInstancesReply(const PredictReply &reply, json *const out_json) {
  Status status(SUCCESS);
  auto error_size = reply.error_msg_size();
  auto reply_size = reply.instances().size();
  if (error_size == 1 && reply_size == 0) {
    (*out_json)[kErrorMsg] = reply.error_msg()[0].error_msg();
    return SUCCESS;
  }
  if (error_size != 0 && error_size != instances_nums_) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "reply error size:" << error_size << " is not 0,1 or instances size "
                                          << instances_nums_ << ", reply instances size " << reply_size;
  }
  if (reply_size != instances_nums_) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "reply size:" << reply_size << " is not matched request size:" << instances_nums_;
  }
  (*out_json)[kInstancesReply] = json();
  json &instances_json = (*out_json)[kInstancesReply];
  for (int32_t i = 0; i < instances_nums_; i++) {
    instances_json.push_back(json());
    auto &instance = instances_json.back();
    if (error_size != 0 && reply.error_msg()[i].error_code() != 0) {
      instance[kErrorMsg] = reply.error_msg(i).error_msg();
      continue;
    }
    auto &cur_instance = reply.instances(i);
    auto &items = cur_instance.items();
    if (items.empty()) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "reply instance items is empty";
    }
    for (auto &item : items) {
      instance[item.first] = json();
      auto &value_json = instance[item.first];
      status = ParseReplyDetail(item.second, &value_json);
      if (status != SUCCESS) {
        return status;
      }
    }
  }
  return status;
}
}  // namespace serving
}  // namespace mindspore

================================================
FILE: mindspore_serving/ccsrc/master/restful/http_process.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #ifndef MINDSPORE_SERVING_MASTER_HTTP_PROCESS_H #define MINDSPORE_SERVING_MASTER_HTTP_PROCESS_H #include #include #include #include #include "proto/ms_service.pb.h" #include "master/dispacther.h" #include "common/proto_tensor.h" #include "master/restful/restful_request.h" using nlohmann::json; using std::string; namespace mindspore { namespace serving { constexpr auto kInstancesRequest = "instances"; constexpr auto kInstancesReply = "instances"; constexpr auto kErrorMsg = "error_msg"; constexpr auto kType = "type"; constexpr auto kShape = "shape"; constexpr auto kB64 = "b64"; enum RequestType { kInstanceType = 0, kInvalidType }; enum InstancesType { kNokeyWay = 0, kKeyWay, kInvalidWay }; enum HTTP_DATA_TYPE { HTTP_DATA_NONE, HTTP_DATA_INT, HTTP_DATA_FLOAT, HTTP_DATA_BOOL, HTTP_DATA_STR, HTTP_DATA_OBJ }; class RestfulService { public: RestfulService() = default; ~RestfulService() = default; static void RunRestful(const std::shared_ptr &restful_request); private: void RunRestfulInner(const std::shared_ptr &restful_request, const std::shared_ptr &restful_service); Status CheckObjTypeMatchShape(DataType data_type, const std::vector &shape); std::string GetString(const uint8_t *ptr, size_t length); Status CheckObj(const json &js); Status CheckObjType(const std::string &type); DataType GetObjDataType(const json &js); std::vector GetObjShape(const json &js); std::vector GetArrayShape(const json &json_array); std::vector GetSpecifiedShape(const json &js); DataType GetArrayDataType(const json &json_array, HTTP_DATA_TYPE *type_format); Status CheckReqJsonValid(const json &js_msg); std::string GetStringByDataType(DataType type); bool JsonMatchDataType(const json &js, DataType type); template Status GetScalarData(const json &js, size_t index, bool is_bytes, ProtoTensor *const request_tensor); Status GetScalarByType(DataType type, const json &js, size_t index, ProtoTensor *const request_tensor); Status RecursiveGetArray(const json &json_data, size_t depth, size_t data_index, HTTP_DATA_TYPE type_format, ProtoTensor *const request_tensor); Status GetArrayData(const json &js, size_t data_index, HTTP_DATA_TYPE type, ProtoTensor *const request_tensor); Status ParseReqCommonMsg(const std::shared_ptr &restful_request, proto::PredictRequest *const request); Status ParseInstancesMsg(const json &js_msg, proto::PredictRequest *const request); Status GetInstancesType(const json &instances); Status ParseKeyInstances(const json &instances, proto::PredictRequest *const request); Status PaserKeyOneInstance(const json &instance_msg, proto::PredictRequest *const request); Status ParseItemScalar(const json &value, ProtoTensor *const pb_tensor); Status ParseItemArray(const json &value, ProtoTensor *const pb_tensor); Status ParseItemObject(const json &value, ProtoTensor *const pb_tensor); Status ParseItem(const json &value, ProtoTensor *const pb_tensor); Status ParseRequest(const std::shared_ptr &restful_request, proto::PredictRequest *const request); Status ParseReply(const proto::PredictReply &reply, json *const out_json); // parse reply:trans RequestReply to http msg RequestType GetReqType(const std::string &str); std::string GetReqTypeStr(RequestType req_type); Status CheckReply(const ProtoTensor &pb_tensor); Status ParseInstancesReply(const proto::PredictReply &reply, json *const out_json); Status ParseReplyDetail(const proto::Tensor &tensor, json *const js); Status ParseScalar(const ProtoTensor &pb_tensor, size_t index, json *const js); Status RecursiveParseArray(const ProtoTensor &pb_tensor, size_t depth, size_t 
pos, json *const out_json); template Status ParseScalarData(const ProtoTensor &pb_tensor, bool is_bytes, size_t index, json *const js); template bool IsString(); RequestType request_type_{kInvalidType}; InstancesType instances_type_{kInvalidWay}; int64_t instances_nums_{0}; std::vector request_type_list_ = {kInstancesRequest}; proto::PredictRequest request_; proto::PredictReply reply_; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_MASTER_HTTP_PROCESS_H ================================================ FILE: mindspore_serving/ccsrc/master/restful/restful_request.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "master/restful/restful_request.h" #include #include #include #include #include namespace { const char kUrlKeyModel[] = "model"; const char kUrlKeyVersion[] = "version"; const char kUrlSplit[] = "/"; const char kUrlKeyEnd[] = ":"; } // namespace namespace mindspore { namespace serving { DecomposeEvRequest::DecomposeEvRequest(struct evhttp_request *request, int max_msg_size) : event_request_(request), max_msg_size_(max_msg_size) {} std::string DecomposeEvRequest::UrlQuery(const std::string &url, const std::string &key) const { std::string::size_type start_pos(0); if (key == kUrlKeyEnd) { if ((start_pos = url_.find(kUrlKeyEnd)) != std::string::npos) { return url_.substr(start_pos + 1, url_.size()); } } size_t key_size = key.size() + 1; std::string::size_type end_pos(0); if ((start_pos = url.find(key)) != std::string::npos) { end_pos = std::min(url.find(kUrlSplit, start_pos + key_size), url.find(kUrlKeyEnd, start_pos + key_size)); if (end_pos == std::string::npos) { return url.substr(start_pos + key_size); } return url.substr(start_pos + key_size, end_pos - start_pos - key_size); } return ""; } Status DecomposeEvRequest::GetPostMessageToJson() { Status status(SUCCESS); std::string message; size_t input_size = evbuffer_get_length(event_request_->input_buffer); if (input_size == 0) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "http message invalid"; } else if (input_size > max_msg_size_) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "http message is bigger than " << max_msg_size_; } else { message.resize(input_size); auto src_data = evbuffer_pullup(event_request_->input_buffer, -1); if (src_data == nullptr) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "get http message failed."; } if (memcpy_s(message.data(), input_size, src_data, input_size) != EOK) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "copy http message failed."; } } MSI_TIME_STAMP_START(ParseJson) try { request_message_ = nlohmann::json::parse(message); } catch (nlohmann::json::exception &e) { std::string json_exception = e.what(); MSI_LOG_ERROR << "Illegal JSON format." + json_exception; // Remove invalid character that cannot be converted to Json. 
const std::string find_msg = "invalid literal"; // invalid literal; last read: '{invalid character}' auto find_pos = json_exception.find(find_msg); if (find_pos != std::string::npos) { json_exception = json_exception.substr(0, find_pos + find_msg.size()); } return INFER_STATUS(INVALID_INPUTS) << "Illegal JSON format." + json_exception; } MSI_TIME_STAMP_END(ParseJson) return status; } Status DecomposeEvRequest::CheckRequestMethodValid() { auto cmd = evhttp_request_get_command(event_request_); if (cmd != EVHTTP_REQ_POST) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "http message only support POST right now"; } request_method_ = "POST"; return SUCCESS; } Status DecomposeEvRequest::Decompose() { Status status(SUCCESS); status = CheckRequestMethodValid(); if (status != SUCCESS) { return status; } status = GetPostMessageToJson(); if (status != SUCCESS) { return status; } // eg: /model/resnet/version/1:predict url_ = evhttp_request_get_uri(event_request_); if (url_.empty()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "evhttp url is empty."; } MSI_LOG_INFO << "url_: " << url_; model_name_ = UrlQuery(url_, kUrlKeyModel); if (model_name_.empty()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "please check url, the keyword:[model] must contain."; } MSI_LOG_INFO << "model_name_: " << model_name_; if (url_.find(kUrlKeyVersion) != std::string::npos) { auto version_str = UrlQuery(url_, kUrlKeyVersion); try { auto version = std::stol(version_str); if (version < 0 || version >= UINT32_MAX) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "please check url, version number range failed, request version number " << version_str; } version_ = static_cast(version); } catch (const std::invalid_argument &) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "please check url, the keyword:[version] value invalid, request version number " << version_str; } catch (const std::out_of_range &) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "please check url, version number range failed, request version number " << version_str; } MSI_LOG_INFO << "version_: " << version_; } service_method_ = UrlQuery(url_, kUrlKeyEnd); if (service_method_.empty()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "please check url, the keyword:[service method] must contain."; } MSI_LOG_INFO << "service_method_: " << service_method_; return status; } RestfulRequest::RestfulRequest(std::shared_ptr request) : decompose_event_request_(std::move(request)) {} RestfulRequest::~RestfulRequest() { if (replay_buffer_ != nullptr) { evbuffer_free(replay_buffer_); replay_buffer_ = nullptr; } } Status RestfulRequest::RestfulReplayBufferInit() { replay_buffer_ = evbuffer_new(); if (replay_buffer_ == nullptr) { return INFER_STATUS_LOG_ERROR(FAILED) << "create restful replay buffer fail"; } return SUCCESS; } Status RestfulRequest::RestfulReplay(const std::string &replay) { if (replay_buffer_ == nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "replay_buffer_ is nullptr"; } if (decompose_event_request_ == nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "decompose_event_request_ is nullptr"; } auto &request = decompose_event_request_->event_request_; if (request == nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "decompose_event_request_->event_request_ is nullptr"; } auto resp_headers = evhttp_request_get_output_headers(request); (void)evhttp_add_header(resp_headers, "Content-Type", "application/json"); (void)evbuffer_add(replay_buffer_, replay.data(), replay.size()); evhttp_send_reply(request, HTTP_OK, 
"Client", replay_buffer_); return SUCCESS; } void RestfulRequest::ErrorMessage(const Status &status) { std::string out_error_str; try { nlohmann::json error_json = {{"error_msg", status.StatusMessage()}}; out_error_str = error_json.dump(); } catch (nlohmann::json::exception &e) { nlohmann::json error_json = {{"error_msg", "Illegal JSON format."}}; out_error_str = error_json.dump(); } (void)RestfulReplay(out_error_str); } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/master/restful/restful_request.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_MASTER_RESTFUL_REQUEST_H #define MINDSPORE_SERVING_MASTER_RESTFUL_REQUEST_H #include #include #include #include #include #include "common/serving_common.h" namespace mindspore { namespace serving { class DecomposeEvRequest { public: explicit DecomposeEvRequest(struct evhttp_request *request, int max_msg_size); ~DecomposeEvRequest() = default; std::string UrlQuery(const std::string &url, const std::string &key) const; Status CheckRequestMethodValid(); Status Decompose(); Status GetPostMessageToJson(); evhttp_request *event_request_; std::string request_method_; std::string model_name_; std::string url_; std::string service_method_; uint32_t version_{}; uint32_t max_msg_size_{}; nlohmann::json request_message_; }; class RestfulRequest { public: explicit RestfulRequest(std::shared_ptr request); ~RestfulRequest(); RestfulRequest(const RestfulRequest &other) = delete; RestfulRequest &operator=(const RestfulRequest &other) = delete; Status RestfulReplayBufferInit(); Status RestfulReplay(const std::string &replay); void ErrorMessage(const Status &status); std::shared_ptr decompose_event_request() { return decompose_event_request_; } private: std::shared_ptr decompose_event_request_{nullptr}; evbuffer *replay_buffer_ = nullptr; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_MASTER_RESTFUL_REQUEST_H ================================================ FILE: mindspore_serving/ccsrc/master/restful/restful_server.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include #include #include "openssl/ssl.h" #include "openssl/err.h" #include "event2/bufferevent.h" #include "event2/http.h" #include "event2/bufferevent_ssl.h" #include "master/restful/http_handle.h" #include "master/restful/restful_server.h" #include "common/utils.h" #include "master/restful/http_process.h" namespace mindspore::serving { const std::vector kCiphers = { "ECDHE-RSA-AES128-GCM-SHA256", "ECDHE-ECDSA-AES128-GCM-SHA256", "ECDHE-RSA-AES256-GCM-SHA384", "ECDHE-ECDSA-AES256-GCM-SHA384", "ECDHE-RSA-CHACHA20-POLY1305", "ECDHE-PSK-CHACHA20-POLY1305", "ECDHE-ECDSA-AES128-CCM", "ECDHE-ECDSA-AES256-CCM", "ECDHE-ECDSA-CHACHA20-POLY1305"}; void RestfulServer::Committer(const std::shared_ptr &restful_request) { thread_pool_.commit([restful_request]() { RestfulService::RunRestful(restful_request); }); } void RestfulServer::DispatchEvHttpRequest(evhttp_request *request) { Status status(SUCCESS); auto de_request = std::make_unique(request, max_msg_size_); Status de_status = de_request->Decompose(); auto restful_request = std::make_shared(std::move(de_request)); status = restful_request->RestfulReplayBufferInit(); if (status != SUCCESS) { restful_request->ErrorMessage(status); return; } if (de_status != SUCCESS) { restful_request->ErrorMessage(de_status); return; } Committer(restful_request); } void RestfulServer::EvCallBack(evhttp_request *request, void *arg) { auto *restful_server = static_cast(arg); restful_server->DispatchEvHttpRequest(const_cast(request)); } Status RestfulServer::CreatRestfulServer(int time_out_second) { evthread_use_pthreads(); auto status = InitEvHttp(); if (status != SUCCESS) { return status; } evhttp_set_gencb(event_http_, &EvCallBack, this); evhttp_set_timeout(event_http_, time_out_second); return SUCCESS; } Status RestfulServer::CreatHttpsServer(int time_out_second, const SSLConfig &ssl_config) { InitOpenSSL(); evthread_use_pthreads(); Status status; status = InitEvHttp(); if (status != SUCCESS) { return status; } SSL_CTX *ctx = SSL_CTX_new(SSLv23_method()); SSL_CTX_set_options(ctx, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1); std::string cipher_list = kCiphers[0]; for (size_t index = 1; index < kCiphers.size(); ++index) { cipher_list += ':'; cipher_list += kCiphers[index]; } if (!SSL_CTX_set_cipher_list(ctx, cipher_list.c_str())) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "SSL use set cipher list failed!"; return status; } if (ssl_config.verify_client) { SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); if (!ssl_config.custom_ca.empty() && SSL_CTX_load_verify_locations(ctx, ssl_config.custom_ca.c_str(), nullptr) != 1) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: load root certificate from " << ssl_config.custom_ca << " failed"; return status; } else { if (SSL_CTX_set_default_verify_paths(ctx) != 1) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: set default verify paths failed"; return status; } } } EC_KEY *ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); if (ecdh == nullptr) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: EC_KEY_new_by_curve_name failed"; return status; } if (!SSL_CTX_set_tmp_ecdh(ctx, ecdh)) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: SSL_CTX_set_tmp_ecdh failed"; return status; } status = ServerSetupCerts(ctx, ssl_config); if (status != SUCCESS) { return status; } evhttp_set_bevcb(event_http_, bevcb, ctx); evhttp_set_gencb(event_http_, 
&EvCallBack, this); evhttp_set_timeout(event_http_, time_out_second); return SUCCESS; } Status RestfulServer::ServerSetupCerts(SSL_CTX *ctx, const SSLConfig &ssl_config) { Status status; if (SSL_CTX_use_certificate_chain_file(ctx, ssl_config.certificate.c_str()) != 1) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: load certificate_chain from " << ssl_config.certificate << " failed"; return status; } if (SSL_CTX_use_PrivateKey_file(ctx, ssl_config.private_key.c_str(), SSL_FILETYPE_PEM) != 1) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: load private_key from " << ssl_config.private_key << " failed"; return status; } if (SSL_CTX_check_private_key(ctx) != 1) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: private_key is not consistent with certificate " << ssl_config.certificate; return status; } return SUCCESS; } struct bufferevent *RestfulServer::bevcb(struct event_base *base, void *args) { struct bufferevent *r; SSL_CTX *ctx = static_cast(args); r = bufferevent_openssl_socket_new(base, -1, SSL_new(ctx), BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE); return r; } Status RestfulServer::InitEvHttp() { event_base_ = event_base_new(); Status status(SUCCESS); if (event_base_ == nullptr) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: RESTful server start failed, new http event failed"; return status; } event_http_ = evhttp_new(event_base_); if (event_http_ == nullptr) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: RESTful server start failed, create http server failed"; event_base_free(event_base_); event_base_ = nullptr; return status; } return status; } void RestfulServer::FreeEvhttp() { if (event_http_ != nullptr) { evhttp_free(event_http_); event_http_ = nullptr; } if (event_base_ != nullptr) { event_base_free(event_base_); event_base_ = nullptr; } } void RestfulServer::RunEvhttp() { auto event_http_run = [this]() { MSI_LOG(INFO) << "Serving RESTful server listening on " << socket_address_; std::cout << "Serving: Serving RESTful server start success, listening on " << socket_address_ << std::endl; event_base_dispatch(event_base_); }; event_thread_ = std::thread(event_http_run); } Status RestfulServer::StartRestfulServer() { Status status(SUCCESS); uint16_t port; std::string ip; status = GetSocketAddress(&ip, &port); if (status != SUCCESS) { return status; } auto ret = evhttp_bind_socket(event_http_, ip.c_str(), port); if (ret != 0) { FreeEvhttp(); status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: RESTful server start failed, bind to the socket address " << socket_address_ << " failed"; return status; } RunEvhttp(); return SUCCESS; } Status RestfulServer::GetSocketAddress(std::string *ip, uint16_t *port) { MSI_EXCEPTION_IF_NULL(ip); MSI_EXCEPTION_IF_NULL(port); Status status; std::string prefix = "unix:"; if (socket_address_.substr(0, prefix.size()) == prefix) { status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: RESTful server does not support binding to unix domain socket"; return status; } status = common::CheckAddress(socket_address_, "RESTful server", ip, port); if (status != SUCCESS) { return status; } return SUCCESS; } Status RestfulServer::Start(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_size, int time_out_second) { Status status(SUCCESS); if (in_running_) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: RESTful server is already running"; } socket_address_ = socket_address; constexpr int mbytes_to_bytes = 
static_cast(1u << 20); max_msg_size_ = max_msg_size * mbytes_to_bytes; if (ssl_config.use_ssl) { status = CreatHttpsServer(time_out_second, ssl_config); } else { status = CreatRestfulServer(time_out_second); } if (status != SUCCESS) { return status; } status = StartRestfulServer(); if (status != SUCCESS) { return status; } in_running_ = true; return status; } void RestfulServer::Stop() { if (in_running_) { event_base_loopexit(event_base_, nullptr); event_thread_.join(); } in_running_ = false; FreeEvhttp(); } void RestfulServer::InitOpenSSL() { #if (OPENSSL_VERSION_NUMBER < 0x10100000L) || (defined(LIBRESSL_VERSION_NUMBER) && OPENSSL_VERSION_NUMBER < 0x20700000L) SSL_library_init(); ERR_load_crypto_strings(); SSL_load_error_strings(); OpenSSL_add_all_algorithms(); #endif } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/master/restful/restful_server.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_MASTER_RESTFUL_SERVER_H #define MINDSPORE_SERVING_MASTER_RESTFUL_SERVER_H #include #include #include #include #include #include #include #include #include #include #include "openssl/ssl.h" #include "openssl/err.h" #include "event2/bufferevent.h" #include "master/restful/restful_request.h" #include "common/serving_common.h" #include "common/thread_pool.h" #include "common/ssl_config.h" namespace mindspore::serving { constexpr const uint32_t kDefaultRestfulThreadPoolNum = 3; class RestfulServer { public: RestfulServer() : thread_pool_(kDefaultRestfulThreadPoolNum) {} ~RestfulServer() { Stop(); } Status Start(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_size, int time_out_second); void Stop(); private: Status CreatRestfulServer(int time_out_second); Status CreatHttpsServer(int time_out_second, const SSLConfig &ssl_config); static void EvCallBack(evhttp_request *request, void *arg); void DispatchEvHttpRequest(evhttp_request *request); void Committer(const std::shared_ptr &restful_request); Status StartRestfulServer(); Status GetSocketAddress(std::string *ip, uint16_t *port); static void InitOpenSSL(); static Status ServerSetupCerts(SSL_CTX *ctx, const SSLConfig &ssl_config); static struct bufferevent *bevcb(struct event_base *base, void *args); Status InitEvHttp(); void FreeEvhttp(); void RunEvhttp(); std::string socket_address_; int max_msg_size_ = 0; bool in_running_ = false; struct evhttp *event_http_ = nullptr; struct event_base *event_base_ = nullptr; std::thread event_thread_; ThreadPool thread_pool_; }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_MASTER_RESTFUL_SERVER_H ================================================ FILE: mindspore_serving/ccsrc/master/servable_endpoint.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file 
except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "master/servable_endpoint.h" namespace mindspore::serving { ServableEndPoint::ServableEndPoint(const ServableReprInfo &repr) : worker_repr_(repr) { version_number_ = worker_repr_.version_number; } ServableEndPoint::~ServableEndPoint() { Clear(); } Status ServableEndPoint::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &on_finish) { auto method_name = request.servable_spec().method_name(); auto it = model_thread_list_.find(method_name); if (it == model_thread_list_.end()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Cannot find model thread of method " << method_name; } auto status = it->second->DispatchAsync(request, reply, on_finish); return status; } Status ServableEndPoint::RegisterWorker(const ServableRegSpec &servable_spec, std::shared_ptr worker) { auto &methods = servable_spec.methods; // first init if (worker_contexts_.empty()) { methods_ = servable_spec.methods; if (version_number_ == 0) { version_number_ = servable_spec.version_number; } for (auto &method : methods) { if (servable_spec.batch_size <= 0) { MSI_LOG_ERROR << "Register Worker,method batch_size should be greater than 0"; return FAILED; } auto model_thread = std::make_shared(servable_spec.servable_name, method.name, servable_spec.version_number, servable_spec.batch_size, method); (void)model_thread_list_.emplace(method.name, model_thread); } } worker_contexts_.push_back(worker); std::vector method_names; for (auto &method : methods) { auto it = model_thread_list_.find(method.name); if (it == model_thread_list_.end()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Cannot find method " << method.name << " registered before"; } it->second->AddWorker(worker->GetWorkerPid(), worker); // cppcheck-suppress useStlAlgorithm method_names.push_back(method.name); } MSI_LOG_INFO << "Register to servable endpoint success, servable name: " << worker_repr_.servable_name << ", version number: " << servable_spec.version_number << ", methods: " << method_names << ", worker address: " << worker->GetWorkerAddress(); return SUCCESS; } Status ServableEndPoint::UnregisterWorker(const std::string &worker_address) { auto it = std::find_if(worker_contexts_.begin(), worker_contexts_.end(), [worker_address](const std::shared_ptr &item) { return item->GetWorkerAddress() == worker_address; }); if (it != worker_contexts_.end()) { auto worker = *it; MSI_LOG_INFO << "Unregister worker success, " << worker_repr_.repr << ", version number: " << version_number_ << ", worker address: " << worker_address; for (auto &model_thread : model_thread_list_) { model_thread.second->DelWorker(worker->GetWorkerPid()); } (void)worker_contexts_.erase(it); return SUCCESS; } MSI_LOG_INFO << "Worker has already been unregistered, " << worker_repr_.repr << ", version number: " << version_number_ << ", worker address: " << worker_address; return FAILED; } void ServableEndPoint::Clear() { worker_contexts_.clear(); model_thread_list_.clear(); } } // namespace mindspore::serving ================================================ FILE: 
mindspore_serving/ccsrc/master/servable_endpoint.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_MASTER_SERVABLE_ENDPOINT_H #define MINDSPORE_SERVING_MASTER_SERVABLE_ENDPOINT_H #include #include #include #include #include #include #include "common/serving_common.h" #include "master/worker_context.h" #include "master/model_thread.h" namespace mindspore::serving { // visit by dispatcher class ServableEndPoint { public: explicit ServableEndPoint(const ServableReprInfo &repr); ~ServableEndPoint(); Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &on_finish); Status RegisterWorker(const ServableRegSpec &servable_spec, std::shared_ptr worker); Status UnregisterWorker(const std::string &worker_address); void Clear(); std::string GetServableName() const { return worker_repr_.servable_name; } uint64_t GetVersionNumber() const { return version_number_; } std::vector GetMethods() const { return methods_; } private: std::map> model_thread_list_; ServableReprInfo worker_repr_; std::vector methods_; std::vector> worker_contexts_; uint32_t version_number_ = 0; }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_MASTER_SERVABLE_ENDPOINT_H ================================================ FILE: mindspore_serving/ccsrc/master/server.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "master/server.h" #include #include #include "common/serving_common.h" #include "master/grpc/grpc_process.h" #include "master/grpc/grpc_server.h" namespace mindspore { namespace serving { Status Server::StartGrpcServer(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_mb_size) { if (grpc_async_server_) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Serving gRPC server is already running"; } if (max_msg_mb_size > gRpcMaxMBMsgSize) { MSI_LOG_WARNING << "The maximum Serving gRPC message size is 512MB and will be updated from " << max_msg_mb_size << "MB to 512MB"; max_msg_mb_size = gRpcMaxMBMsgSize; } grpc_async_server_ = std::make_shared(dispatcher_); return grpc_async_server_->Start(socket_address, ssl_config, max_msg_mb_size, "Serving gRPC"); } Status Server::StartGrpcMasterServer(const std::string &master_address) { if (master_async_server_) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Master gRPC server is already running"; } SSLConfig ssl_config; ssl_config.use_ssl = false; master_async_server_ = std::make_shared(dispatcher_); return master_async_server_->Start(master_address, ssl_config, gRpcMaxMBMsgSize, "Master"); } Status Server::StartRestfulServer(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_mb_size, int time_out_second) { return restful_server_.Start(socket_address, ssl_config, max_msg_mb_size, time_out_second); } void Server::Clear() { MSI_LOG_INFO << "Server start to clean"; dispatcher_->Clear(); restful_server_.Stop(); if (master_async_server_) { master_async_server_->Stop(); master_async_server_ = nullptr; } if (grpc_async_server_) { grpc_async_server_->Stop(); grpc_async_server_ = nullptr; } MSI_LOG_INFO << "Server end to clean"; } Server::Server() = default; Server &Server::Instance() { static Server server; return server; } bool Server::OnlyModelStage(const std::string &servable_name) { return dispatcher_->OnlyModelStage(servable_name); } Server::~Server() = default; } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/master/server.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_MASTER_SERVER_H #define MINDSPORE_SERVING_MASTER_SERVER_H #include #include #include "common/serving_common.h" #include "common/grpc_server.h" #include "master/restful/restful_server.h" #include "master/dispacther.h" #include "master/grpc/grpc_server.h" #include "master/grpc/master_server.h" #include "common/ssl_config.h" namespace mindspore { namespace serving { class MS_API Server { public: Server(); ~Server(); Status StartGrpcServer(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_mb_size = gRpcDefaultMsgMBSize); Status StartRestfulServer(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_mb_size = gRpcDefaultMsgMBSize, int time_out_second = 100); Status StartGrpcMasterServer(const std::string &master_address); void Clear(); bool OnlyModelStage(const std::string &servable_name); std::shared_ptr GetDispatcher() { return dispatcher_; } static Server &Instance(); private: std::shared_ptr dispatcher_ = std::make_shared(); std::shared_ptr grpc_async_server_ = nullptr; std::shared_ptr master_async_server_ = nullptr; RestfulServer restful_server_; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_MASTER_SERVER_H ================================================ FILE: mindspore_serving/ccsrc/master/worker_context.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/
#include "master/worker_context.h"
#include "master/servable_endpoint.h"
#include "master/server.h"

namespace mindspore::serving {
// from py
std::shared_ptr<WorkerContext> WorkerContext::PyInitWorkerContext(std::string servable_name, uint32_t version_number,
                                                                  std::string repr, uint64_t worker_pid) {
  ServableReprInfo servable_repr;
  servable_repr.servable_name = servable_name;
  servable_repr.version_number = version_number;
  servable_repr.repr = repr;
  return Server::Instance().GetDispatcher()->InitWorkerContext(servable_repr, worker_pid);
}

// from Dispatcher
Status WorkerContext::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                                    const PredictOnFinish &on_finish) {
  auto shared_this = shared_from_this();
  PredictOnFinish callback = [shared_this, on_finish, reply]() {
    auto &error_msg = reply->error_msg();
    auto has_error = std::any_of(error_msg.begin(), error_msg.end(),
                                 [](const proto::ErrorMsg &msg) { return msg.error_code() != 0; });
    if (!has_error && reply->instances_size() != 0) {
      shared_this->normal_handled_count += 1;
      shared_this->total_normal_handled_count += 1;
    } else {
      shared_this->abnormal_handled_count += 1;
      shared_this->total_abnormal_handled_count += 1;
    }
    on_finish();
  };
  std::unique_lock<std::mutex> lock(lock_);
  if (status_ != kWorkerStatusReady && !notify_worker_) {
    return INFER_STATUS_LOG_ERROR(WORKER_UNAVAILABLE) << "Worker is not ready";
  }
  request_count += 1;
  return notify_worker_->DispatchAsync(request, reply, callback);
}

// from worker
void WorkerContext::OnWorkerRegRequest(const WorkerRegSpec &worker_spec, std::shared_ptr<BaseNotifyWorker> notify) {
  std::unique_lock<std::mutex> lock(lock_);
  MSI_LOG_INFO << "Receive worker registered message, " << servable_repr_.repr << ", worker pid: " << worker_pid_
               << ", worker address: " << worker_spec.worker_address;
  worker_spec_ = worker_spec;
  notify_worker_ = notify;
}

void WorkerContext::OnReady() {
  std::unique_lock<std::mutex> lock(lock_);
  MSI_LOG_INFO << "Notify worker ready, " << servable_repr_.repr << ", worker pid: " << worker_pid_
               << ", worker address: " << worker_spec_.worker_address;
  status_ = kWorkerStatusReady;
}

void WorkerContext::OnExit() {
  std::unique_lock<std::mutex> lock(lock_);
  MSI_LOG_INFO << "Notify worker exit, " << servable_repr_.repr << ", worker pid: " << worker_pid_
               << ", worker address: " << worker_spec_.worker_address;
  status_ = kWorkerStatusNotifyExit;
  notify_worker_ = nullptr;
}

void WorkerContext::OnStartError(const std::string &notified_error) {
  std::unique_lock<std::mutex> lock(lock_);
  MSI_LOG_ERROR << "Notify worker start-up error, " << servable_repr_.repr << ", worker pid: " << worker_pid_;
  status_ = kWorkerStatusNotifyFailed;
  notify_worker_ = nullptr;
  notified_error_ = notified_error;
}

void WorkerContext::OnNotAvailable() {
  std::unique_lock<std::mutex> lock(lock_);
  MSI_LOG_ERROR << "Notify worker not available, " << servable_repr_.repr << ", worker pid: " << worker_pid_;
  if (status_ != kWorkerStatusNotifyExit && status_ != kWorkerStatusNotAlive) {
    status_ = kWorkerStatusNotAvailable;
  }
  notify_worker_ = nullptr;
}

void WorkerContext::OnNotAlive() {
  if (HasExitNotified() || HasErrorNotified()) {
    return;
  }
  std::unique_lock<std::mutex> lock(lock_);
  MSI_LOG_INFO << "Notify worker not alive, " << servable_repr_.repr << ", worker pid: " << worker_pid_
               << ", worker address: " << worker_spec_.worker_address;
  if (status_ != kWorkerStatusNotifyExit) {
    status_ = kWorkerStatusNotAlive;
  }
  notify_worker_ = nullptr;
}

// from py
void WorkerContext::PyNotifyNotAlive() { Server::Instance().GetDispatcher()->NotifyWorkerNotAlive(this); }

void WorkerContext::PyNotifyStartFailed(const std::string &notified_error) { OnStartError(notified_error); }

void WorkerContext::NotifyNotAvailable() { Server::Instance().GetDispatcher()->NotifyWorkerNotAvailable(this); }

void WorkerContext::UpdateWorkerPid(uint64_t new_worker_pid) {
  std::unique_lock<std::mutex> lock(lock_);
  MSI_LOG_INFO << "Update worker pid from " << worker_pid_ << " to " << new_worker_pid;
  if (status_ != kWorkerStatusReady) {
    status_ = kWorkerStatusStarting;
  }
  worker_pid_ = new_worker_pid;
  normal_handled_count = 0;
  abnormal_handled_count = 0;
}

void WorkerContext::Clear() {
  std::unique_lock<std::mutex> lock(lock_);
  notify_worker_ = nullptr;
  status_ = kWorkerStatusNotAlive;
}

bool WorkerContext::OwnDevice() const { return worker_spec_.servable_spec.own_device; }

void WorkerContext::PrintStatus() const {
  auto repr = servable_repr_.repr;
  switch (status_) {
    case kWorkerStatusNotAlive:
      MSI_LOG_INFO << "worker " << GetWorkerPid() << " status is kWorkerStatusNotAlive, " << repr;
      break;
    case kWorkerStatusStarting:
      MSI_LOG_INFO << "worker " << GetWorkerPid() << " status is kWorkerStatusStarting, " << repr;
      break;
    case kWorkerStatusReady:
      MSI_LOG_INFO << "worker " << GetWorkerPid() << " status is kWorkerStatusReady, " << repr;
      break;
    case kWorkerStatusNotifyExit:
      MSI_LOG_INFO << "worker " << GetWorkerPid() << " status is kWorkerStatusNotifyExit, " << repr;
      break;
    case kWorkerStatusNotifyFailed:
      MSI_LOG_INFO << "worker " << GetWorkerPid() << " status is kWorkerStatusNotifyFailed, " << repr;
      break;
    case kWorkerStatusNotAvailable:
      MSI_LOG_INFO << "worker " << GetWorkerPid() << " status is kWorkerStatusNotAvailable, " << repr;
      break;
  }
}
}  // namespace mindspore::serving
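// Illustration only, not part of worker_context.cc: a minimal standalone
// sketch of the counting pattern in WorkerContext::DispatchAsync above. The
// completion callback is wrapped so every finished request bumps either the
// normal or the abnormal atomic counter before the caller's callback runs.
// "Ctx" and "WrapOnFinish" are made-up stand-ins, not the real classes.
#include <atomic>
#include <functional>
#include <iostream>

struct Ctx {
  std::atomic_uint64_t normal{0};
  std::atomic_uint64_t abnormal{0};
};

std::function<void()> WrapOnFinish(Ctx *ctx, const bool *has_error_flag, std::function<void()> on_finish) {
  return [ctx, has_error_flag, on_finish]() {
    if (*has_error_flag) {
      ctx->abnormal += 1;
    } else {
      ctx->normal += 1;
    }
    on_finish();  // caller's callback still runs exactly once
  };
}

int main() {
  Ctx ctx;
  bool has_error = false;
  auto cb = WrapOnFinish(&ctx, &has_error, [] { std::cout << "request done" << std::endl; });
  cb();  // simulate one successful completion
  std::cout << ctx.normal << " ok, " << ctx.abnormal << " failed" << std::endl;
  return 0;
}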
================================================
FILE: mindspore_serving/ccsrc/master/worker_context.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_MASTER_WORKER_CONTEXT_H
#define MINDSPORE_SERVING_MASTER_WORKER_CONTEXT_H
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include "proto/ms_worker.grpc.pb.h"
#include "common/serving_common.h"
#include "master/notify_worker/base_notify.h"

namespace mindspore::serving {
class ServableEndPoint;

enum WorkerStatus {
  kWorkerStatusNotAlive = 1,
  kWorkerStatusStarting,
  kWorkerStatusReady,
  kWorkerStatusNotifyExit,
  kWorkerStatusNotifyFailed,
  kWorkerStatusNotAvailable,
};

struct ServableReprInfo {
  std::string servable_name;
  uint32_t version_number = 0;
  std::string repr;
};

class MS_API WorkerContext : public std::enable_shared_from_this<WorkerContext> {
 public:
  WorkerContext() = default;
  ~WorkerContext() { Clear(); }
  bool HasErrorNotified() const { return status_ == kWorkerStatusNotifyFailed; }
  bool HasExitNotified() const { return status_ == kWorkerStatusNotifyExit; }
  std::string GetNotifiedError() const { return notified_error_; }
  bool HasReady() const { return status_ == kWorkerStatusReady; }
  bool IsInStarting() const { return status_ == kWorkerStatusStarting; }
  bool IsUnavailable() const { return status_ == kWorkerStatusNotAvailable; }
  void PrintStatus() const;
  uint64_t GetNormalHandledCount() const { return normal_handled_count; }
  uint64_t GetWorkerPid() const { return worker_pid_; }
  WorkerRegSpec GetWorkerSpec() const { return worker_spec_; }
  ServableReprInfo GetServableReprInfo() const { return servable_repr_; }
  std::string GetWorkerAddress() const { return worker_spec_.worker_address; }
  void InitServableReprInfo(const ServableReprInfo &repr) { servable_repr_ = repr; }
  // from py
  static std::shared_ptr<WorkerContext> PyInitWorkerContext(std::string servable_name, uint32_t version_number,
                                                            std::string repr, uint64_t worker_pid);
  void PyNotifyNotAlive();
  void PyNotifyStartFailed(const std::string &notified_error);
  void NotifyNotAvailable();
  void UpdateWorkerPid(uint64_t new_worker_pid);
  // from Dispatcher
  Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                       const PredictOnFinish &on_finish);
  // from worker
  void OnWorkerRegRequest(const WorkerRegSpec &worker_spec, std::shared_ptr<BaseNotifyWorker> notify);
  void OnReady();
  void OnExit();
  void OnStartError(const std::string &notified_error);
  void OnNotAvailable();
  // from py
  void OnNotAlive();
  void Clear();
  bool OwnDevice() const;

 private:
  std::mutex lock_;
  ServableReprInfo servable_repr_;
  uint32_t device_id_ = 0;
  uint64_t worker_pid_ = 0;
  // from worker register info
  WorkerRegSpec worker_spec_;
  std::shared_ptr<BaseNotifyWorker> notify_worker_ = nullptr;
  // from python env
  WorkerStatus status_ = kWorkerStatusNotAlive;
  std::string notified_error_;
  std::atomic_uint64_t request_count = 0;
  std::atomic_uint64_t total_normal_handled_count = 0;
  std::atomic_uint64_t total_abnormal_handled_count = 0;
  std::atomic_uint64_t normal_handled_count = 0;
  std::atomic_uint64_t abnormal_handled_count = 0;
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_MASTER_WORKER_CONTEXT_H

================================================
FILE: mindspore_serving/ccsrc/python/agent/agent_py.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "python/agent/agent_py.h" #include "common/exit_handle.h" #include "worker/distributed_worker/agent_startup.h" #include "worker/distributed_worker/worker_agent.h" namespace mindspore::serving { DistributedServableConfig PyAgent::GetAgentsConfigsFromWorker(const std::string &distributed_address) { auto status = WorkerAgentStartUp::Instance().GetAgentsConfigsFromWorker(distributed_address); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } DistributedServableConfig config; status = WorkerAgentStartUp::Instance().GetDistributedServableConfig(&config); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } return config; } void PyAgent::NotifyFailed(const std::string &distributed_address) { WorkerAgentStartUp::Instance().NotifyFailed(distributed_address); } void PyAgent::StartAgent(const AgentStartUpConfig &start_config, const std::string &dec_key, const std::string &dec_mode) { auto status = WorkerAgent::Instance().StartAgent(start_config, dec_key, dec_mode); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } void PyAgent::WaitAndClear() { { py::gil_scoped_release release; ExitSignalHandle::Instance().AgentWait(); } WorkerAgent::Instance().Clear(); MSI_LOG_INFO << "Python agent end wait and clear"; } void PyAgent::StopAndClear() { ExitSignalHandle::Instance().Stop(); WorkerAgent::Instance().Clear(); } void PyAgent::StartupNotifyExit(const std::string &distributed_address, const std::string &agent_ip) { WorkerAgentStartUp::Instance().StartupNotifyExit(distributed_address, agent_ip); } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/python/agent/agent_py.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVER_AGENT_PY_H #define MINDSPORE_SERVER_AGENT_PY_H #include #include #include #include #include #include "common/serving_common.h" #include "worker/distributed_worker/common.h" namespace py = pybind11; namespace mindspore { namespace serving { class MS_API PyAgent { public: static void StartAgent(const AgentStartUpConfig &start_config, const std::string &dec_key, const std::string &dec_mode); static DistributedServableConfig GetAgentsConfigsFromWorker(const std::string &distributed_address); static void WaitAndClear(); static void StopAndClear(); // from start up, not agent static void NotifyFailed(const std::string &distributed_address); static void StartupNotifyExit(const std::string &distributed_address, const std::string &agent_ip); }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVER_AGENT_PY_H ================================================ FILE: mindspore_serving/ccsrc/python/master/master_py.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "python/master/master_py.h" #include "common/exit_handle.h" #include "master/server.h" namespace mindspore::serving { void PyMaster::StartGrpcServer(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_mb_size) { auto status = Server::Instance().StartGrpcServer(socket_address, ssl_config, max_msg_mb_size); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } void PyMaster::StartGrpcMasterServer(const std::string &master_address) { auto status = Server::Instance().StartGrpcMasterServer(master_address); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } void PyMaster::StartRestfulServer(const std::string &socket_address, const SSLConfig &ssl_config, int max_msg_mb_size) { auto status = Server::Instance().StartRestfulServer(socket_address, ssl_config, max_msg_mb_size); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } void PyMaster::WaitAndClear() { { py::gil_scoped_release release; ExitSignalHandle::Instance().MasterWait(); } Server::Instance().Clear(); MSI_LOG_INFO << "Python server end wait and clear"; } void PyMaster::StopAndClear() { ExitSignalHandle::Instance().Stop(); Server::Instance().Clear(); } bool PyMaster::OnlyModelStage(const std::string &servable_name) { return Server::Instance().OnlyModelStage(servable_name); } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/python/master/master_py.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVER_MASTER_PY_H
#define MINDSPORE_SERVER_MASTER_PY_H

#include <map>
#include <memory>
#include <string>
#include <vector>
#include <pybind11/pybind11.h>
#include "common/serving_common.h"
#include "common/ssl_config.h"

namespace py = pybind11;

namespace mindspore {
namespace serving {
class MS_API PyMaster {
 public:
  static void StartGrpcServer(const std::string &socket_address, const SSLConfig &ssl_config,
                              int max_msg_mb_size = 100);
  static void StartGrpcMasterServer(const std::string &master_address);
  static void StartRestfulServer(const std::string &socket_address, const SSLConfig &ssl_config,
                                 int max_msg_mb_size = 100);
  static void WaitAndClear();
  static void StopAndClear();
  static bool OnlyModelStage(const std::string &servable_name);
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVER_MASTER_PY_H

================================================
FILE: mindspore_serving/ccsrc/python/serving_py.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include #include "python/worker/worker_py.h" #include "python/worker/servable_py.h" #include "python/tensor_py.h" #include "common/servable.h" #include "common/ssl_config.h" #include "master/server.h" #include "master/master_context.h" #include "master/worker_context.h" #include "worker/context.h" #include "worker/stage_function.h" #include "python/master/master_py.h" #include "python/agent/agent_py.h" #include "common/exit_handle.h" #include "worker/distributed_worker/worker_agent.h" namespace mindspore::serving { void PyRegServable(pybind11::module *m_ptr) { auto &m = *m_ptr; // avoid as numpy object memory copy in PyTensor::AsPythonData py::class_(m, "Tensor_"); py::class_>(m, "StageFunctionStorage_") .def(py::init<>()) .def_static("get_instance", &PyStageFunctionStorage::Instance) .def("register", &PyStageFunctionStorage::Register) .def("get_pycpp_function_info", &PyStageFunctionStorage::GetPyCppFunctionInfo); py::class_(m, "MethodSignature_") .def(py::init<>()) .def_readwrite("servable_name", &MethodSignature::servable_name) .def_readwrite("method_name", &MethodSignature::method_name) .def_readwrite("inputs", &MethodSignature::inputs) .def_readwrite("outputs", &MethodSignature::outputs) .def("add_stage_function", &MethodSignature::AddStageFunction) .def("add_stage_model", &MethodSignature::AddStageModel) .def("set_return", &MethodSignature::SetReturn); py::class_(m, "RequestSpec_") .def(py::init<>()) .def_readwrite("servable_name", &RequestSpec::servable_name) .def_readwrite("version_number", &RequestSpec::version_number) .def_readwrite("method_name", &RequestSpec::method_name); py::class_(m, "CommonModelMeta_") .def(py::init<>()) .def_readwrite("servable_name", &CommonModelMeta::servable_name) .def_readwrite("model_key", &CommonModelMeta::model_key) .def_readwrite("inputs_count", &CommonModelMeta::inputs_count) .def_readwrite("outputs_count", &CommonModelMeta::outputs_count) .def_readwrite("with_batch_dim", &CommonModelMeta::with_batch_dim) .def_readwrite("without_batch_dim_inputs", &CommonModelMeta::without_batch_dim_inputs); py::class_(m, "LocalModelMeta_") .def(py::init<>()) .def_readwrite("model_file", &LocalModelMeta::model_files) .def_readwrite("config_file", &LocalModelMeta::config_file) .def_readwrite("model_context", &LocalModelMeta::model_context) .def("set_model_format", &LocalModelMeta::SetModelFormat); py::class_(m, "ModelContext_") .def(py::init<>()) .def_readwrite("thread_num", &ModelContext::thread_num) .def_readwrite("thread_affinity_core_list", &ModelContext::thread_affinity_core_list) .def_readwrite("enable_parallel", &ModelContext::enable_parallel) .def_readwrite("device_list", &ModelContext::device_list) .def("append_device_info", &ModelContext::AppendDeviceInfo); py::class_(m, "DistributedModelMeta_") .def(py::init<>()) .def_readwrite("rank_size", &DistributedModelMeta::rank_size) .def_readwrite("stage_size", &DistributedModelMeta::stage_size) .def_readwrite("enable_pipeline_infer", &DistributedModelMeta::enable_pipeline_infer); py::class_(m, "ModelMeta_") .def(py::init<>()) .def_readwrite("common_meta", &ModelMeta::common_meta) .def_readwrite("local_meta", &ModelMeta::local_meta) .def_readwrite("distributed_meta", &ModelMeta::distributed_meta); py::class_(m, "ServableSignature_") .def(py::init<>()) .def_readwrite("servable_meta", &ServableSignature::model_metas) .def_readwrite("methods", &ServableSignature::methods); py::class_(m, "ServableRegister_") .def_static("register_model_input_output_info", &PyServableRegister::RegisterInputOutputInfo) 
.def_static("register_method", &PyServableRegister::RegisterMethod) .def_static("declare_model", &PyServableRegister::DeclareModel) .def_static("declare_distributed_model", &PyServableRegister::DeclareDistributedModel) .def_static("run", &PyServableRegister::Run); py::class_(m, "OneRankConfig_") .def(py::init<>()) .def_readwrite("device_id", &OneRankConfig::device_id) .def_readwrite("ip", &OneRankConfig::ip); py::class_(m, "DistributedServableConfig_") .def(py::init<>()) .def_readwrite("common_meta", &DistributedServableConfig::common_meta) .def_readwrite("distributed_meta", &DistributedServableConfig::distributed_meta) .def_readwrite("rank_table_content", &DistributedServableConfig::rank_table_content) .def_readwrite("rank_list", &DistributedServableConfig::rank_list); } void PyRegMaster(pybind11::module *m_ptr) { auto &m = *m_ptr; py::class_(m, "Master_") .def_static("start_grpc_server", &PyMaster::StartGrpcServer) .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) .def_static("start_restful_server", &PyMaster::StartRestfulServer) .def_static("wait_and_clear", &PyMaster::WaitAndClear) .def_static("stop_and_clear", &PyMaster::StopAndClear) .def_static("only_model_stage", &PyMaster::OnlyModelStage); py::class_>(m, "WorkerContext_") .def_static("init_worker", &WorkerContext::PyInitWorkerContext) .def("has_error_notified", &WorkerContext::HasErrorNotified) .def("has_exit_notified", &WorkerContext::HasExitNotified) .def("get_notified_error", &WorkerContext::GetNotifiedError) .def("ready", &WorkerContext::HasReady) .def("print_status", &WorkerContext::PrintStatus) .def("is_in_starting", &WorkerContext::IsInStarting) .def("update_worker_pid", &WorkerContext::UpdateWorkerPid) .def("notify_not_alive", &WorkerContext::PyNotifyNotAlive) .def("notify_start_failed", &WorkerContext::PyNotifyStartFailed) .def_property_readonly("is_unavailable", &WorkerContext::IsUnavailable) .def_property_readonly("normal_handled_count", &WorkerContext::GetNormalHandledCount) .def_property_readonly("address", &WorkerContext::GetWorkerAddress); py::class_(m, "SSLConfig_") .def(py::init<>()) .def_readwrite("certificate", &SSLConfig::certificate) .def_readwrite("private_key", &SSLConfig::private_key) .def_readwrite("custom_ca", &SSLConfig::custom_ca) .def_readwrite("verify_client", &SSLConfig::verify_client) .def_readwrite("use_ssl", &SSLConfig::use_ssl); } void PyRegWorker(pybind11::module *m_ptr) { auto &m = *m_ptr; py::class_(m, "TaskItem_") .def(py::init<>()) .def_readonly("has_stopped", &TaskItem::has_stopped) .def_property_readonly("method_name", [](const TaskItem &item) { return item.task_info.group_name; }) .def_property_readonly("stage_index", [](const TaskItem &item) { return item.task_info.priority; }) .def_property_readonly("task_name", [](const TaskItem &item) { return item.task_info.task_name; }) .def_property_readonly("instance_list", [](const TaskItem &item) { py::tuple instances(item.instance_list.size()); for (size_t i = 0; i < item.instance_list.size(); i++) { instances[i] = PyTensor::AsNumpyTuple(item.instance_list[i]->data); } return instances; }); py::class_(m, "Worker_") .def_static("start_servable", &PyWorker::StartServable, py::call_guard()) .def_static("start_distributed_servable", &PyWorker::StartDistributedServable, py::call_guard()) .def_static("start_extra_servable", &PyWorker::StartExtraServable, py::call_guard()) .def_static("get_declared_model_names", &PyWorker::GetDeclaredModelNames) .def_static("wait_and_clear", &PyWorker::WaitAndClear) 
.def_static("stop_and_clear", PyWorker::StopAndClear) .def_static("enable_pytask_que", PyWorker::EnablePyTaskQueue) .def_static("get_py_task", &PyWorker::GetPyTask, py::call_guard()) .def_static("push_pytask_result", &PyWorker::PushPyTaskResult) .def_static("push_pytask_failed", &PyWorker::PushPyTaskFailed) .def_static("push_pytask_system_failed", &PyWorker::PushPyTaskSystemFailed) .def_static("get_device_type", &PyWorker::GetDeviceType) .def_static("support_reuse_device", &PyWorker::SupportReuseDevice) .def_static("notify_failed", &PyWorker::NotifyFailed); py::class_>(m, "ServableContext_") .def(py::init<>()) .def_static("get_instance", &ServableContext::Instance) .def("set_device_type_str", [](ServableContext &context, const std::string &device_type) { auto status = context.SetDeviceTypeStr(device_type); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } }) .def("set_device_id", &ServableContext::SetDeviceId) .def("set_enable_lite", &ServableContext::SetEnableLite); py::class_>(m, "MasterContext_") .def(py::init<>()) .def_static("get_instance", &MasterContext::Instance) .def("set_max_enqueued_requests", &MasterContext::SetMaxEnqueuedRequests); } void PyRegWorkerAgent(pybind11::module *m_ptr) { auto &m = *m_ptr; py::class_(m, "WorkerAgent_") .def_static("get_agents_config_from_worker", &PyAgent::GetAgentsConfigsFromWorker) .def_static("wait_and_clear", &PyAgent::WaitAndClear) .def_static("stop_and_clear", &PyAgent::StopAndClear) .def_static("notify_failed", &PyAgent::NotifyFailed) .def_static("startup_notify_exit", &PyAgent::StartupNotifyExit) .def_static("start_agent", &PyAgent::StartAgent); py::class_(m, "AgentStartUpConfig_") .def(py::init<>()) .def_readwrite("rank_id", &AgentStartUpConfig::rank_id) .def_readwrite("device_id", &AgentStartUpConfig::device_id) .def_readwrite("model_file_names", &AgentStartUpConfig::model_file_names) .def_readwrite("group_file_names", &AgentStartUpConfig::group_file_names) .def_readwrite("rank_table_json_file_name", &AgentStartUpConfig::rank_table_json_file_name) .def_readwrite("agent_address", &AgentStartUpConfig::agent_address) .def_readwrite("distributed_address", &AgentStartUpConfig::distributed_address) .def_readwrite("common_meta", &AgentStartUpConfig::common_meta); } class PyExitSignalHandle { public: static void Start() { ExitSignalHandle::Instance().Start(); } static bool HasStopped() { return ExitSignalHandle::Instance().HasStopped(); } }; // cppcheck-suppress syntaxError PYBIND11_MODULE(_mindspore_serving, m) { PyRegServable(&m); PyRegMaster(&m); PyRegWorker(&m); PyRegWorkerAgent(&m); py::class_(m, "ExitSignalHandle_") .def_static("start", &PyExitSignalHandle::Start) .def_static("has_stopped", &PyExitSignalHandle::HasStopped); (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { Server::Instance().Clear(); Worker::GetInstance().Clear(); WorkerAgent::Instance().Clear(); }}); } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/python/tensor_py.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "python/tensor_py.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "mindspore_serving/ccsrc/common/tensor.h"

namespace mindspore::serving {
static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) {
  std::vector<ssize_t> strides;
  strides.reserve(shape.size());
  const auto ndim = shape.size();
  for (size_t i = 0; i < ndim; ++i) {
    auto stride = item_size;
    for (size_t j = i + 1; j < ndim; ++j) {
      stride *= shape[j];
    }
    strides.push_back(stride);
  }
  return strides;
}

DataType NumpyTensor::GetDataType(const py::buffer_info &buf) {
  std::set<char> fp_format = {'e', 'f', 'd'};
  std::set<char> int_format = {'b', 'h', 'i', 'l', 'q'};
  std::set<char> uint_format = {'B', 'H', 'I', 'L', 'Q'};
  if (buf.format.size() == 1) {
    char format = buf.format.front();
    if (fp_format.find(format) != fp_format.end()) {
      constexpr int size_of_fp16 = 2;
      constexpr int size_of_fp32 = 4;
      constexpr int size_of_fp64 = 8;
      switch (buf.itemsize) {
        case size_of_fp16:
          return kMSI_Float16;
        case size_of_fp32:
          return kMSI_Float32;
        case size_of_fp64:
          return kMSI_Float64;
      }
    } else if (int_format.find(format) != int_format.end()) {
      switch (buf.itemsize) {
        case sizeof(int8_t):
          return kMSI_Int8;
        case sizeof(int16_t):
          return kMSI_Int16;
        case sizeof(int32_t):
          return kMSI_Int32;
        case sizeof(int64_t):
          return kMSI_Int64;
      }
    } else if (uint_format.find(format) != uint_format.end()) {
      switch (buf.itemsize) {
        case sizeof(uint8_t):
          return kMSI_Uint8;
        case sizeof(uint16_t):
          return kMSI_Uint16;
        case sizeof(uint32_t):
          return kMSI_Uint32;
        case sizeof(uint64_t):
          return kMSI_Uint64;
      }
    } else if (format == '?') {
      return kMSI_Bool;
    }
  }
  MSI_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
  return kMSI_Unknown;
}

static std::string GetPyTypeFormat(DataType data_type) {
  switch (data_type) {
    case kMSI_Float16:
      return "e";
    case kMSI_Float32:
      return py::format_descriptor<float>::format();
    case kMSI_Float64:
      return py::format_descriptor<double>::format();
    case kMSI_Uint8:
      return py::format_descriptor<uint8_t>::format();
    case kMSI_Uint16:
      return py::format_descriptor<uint16_t>::format();
    case kMSI_Uint32:
      return py::format_descriptor<uint32_t>::format();
    case kMSI_Uint64:
      return py::format_descriptor<uint64_t>::format();
    case kMSI_Int8:
      return py::format_descriptor<int8_t>::format();
    case kMSI_Int16:
      return py::format_descriptor<int16_t>::format();
    case kMSI_Int32:
      return py::format_descriptor<int32_t>::format();
    case kMSI_Int64:
      return py::format_descriptor<int64_t>::format();
    case kMSI_Bool:
      return py::format_descriptor<bool>::format();
    default:
      MSI_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
      return "";
  }
}

static bool IsCContiguous(const py::array &input) {
  auto flags = static_cast<unsigned int>(input.flags());
  return (flags & static_cast<unsigned int>(pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_)) != 0;
}

TensorBasePtr PyTensor::MakeTensor(const py::array &input) {
  // Get input buffer info.
  py::buffer_info buf = input.request();
  // Check data types.
  auto buf_type = NumpyTensor::GetDataType(buf);
  if (buf_type == kMSI_Unknown) {
    MSI_LOG(EXCEPTION) << "Unsupported tensor type!";
  }
  // Convert input array to C contiguous if needed.
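  // A non C-contiguous array (for example a transposed view such as np.ones((4, 4)).T, or
  // a strided slice such as a[::2]) cannot be handed over as one flat buffer, so the block
  // below copies it into a temporary contiguous buffer via the Python buffer protocol.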
  std::unique_ptr<uint8_t[]> tmp_buf;
  if (!IsCContiguous(input)) {
    Py_buffer pybuf;
    if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS) || pybuf.len < 0) {
      MSI_LOG(EXCEPTION) << "Failed to get buffer from the input!";
    }
    tmp_buf = std::make_unique<uint8_t[]>(static_cast<size_t>(pybuf.len));
    if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) {
      MSI_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
    }
    PyBuffer_Release(&pybuf);
    buf.ptr = tmp_buf.get();
  }
  // Get tensor shape.
  std::vector<int64_t> shape(buf.shape.begin(), buf.shape.end());
  return std::make_shared<Tensor>(buf_type, shape, buf.ptr, buf.size * buf.itemsize);
}

/// Creates a Tensor from a numpy array without copy
TensorBasePtr PyTensor::MakeTensorNoCopy(const py::array &input) {
  // Check format.
  if (!IsCContiguous(input)) {
    MSI_LOG(EXCEPTION) << "Array should be C contiguous.";
  }
  // Get input buffer info.
  py::buffer_info buf = input.request();
  // Get tensor dtype and check it.
  auto dtype = NumpyTensor::GetDataType(buf);
  if (dtype == kMSI_Unknown) {
    MSI_LOG(EXCEPTION) << "Unsupported data type!";
  }
  // Make a tensor with shared data with numpy array.
  auto tensor_data = std::make_shared<NumpyTensor>(std::move(buf));
  return tensor_data;
}

py::object PyTensor::AsPythonData(const TensorBasePtr &tensor, bool copy) {
  auto data_numpy = std::dynamic_pointer_cast<NumpyTensor>(tensor);
  if (data_numpy) {
    return data_numpy->py_array();
  }
  if (tensor->is_bytes_val_data()) {
    if (tensor->bytes_data_size() != 1) {
      return py::array();
    }
    const uint8_t *data = nullptr;
    size_t bytes_len = 0;
    tensor->get_bytes_data(0, &data, &bytes_len);
    if (tensor->data_type() == kMSI_String) {
      return py::str(reinterpret_cast<const char *>(data), bytes_len);
    }
    std::vector<ssize_t> shape{static_cast<ssize_t>(bytes_len)};
    std::vector<ssize_t> strides = GetStrides(shape, static_cast<ssize_t>(sizeof(uint8_t)));
    py::buffer_info info(reinterpret_cast<void *>(const_cast<uint8_t *>(data)), sizeof(uint8_t),
                         py::format_descriptor<uint8_t>::format(), 1, shape, strides);
    if (!copy) {
      py::object self = py::cast(tensor);
      return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
    } else {
      return py::array(py::dtype(info), info.shape, info.strides, info.ptr);
    }
  } else {
    const auto &tensor_shape = tensor->shape();
    std::vector<ssize_t> shape(tensor_shape.begin(), tensor_shape.end());
    std::vector<ssize_t> strides = GetStrides(shape, static_cast<ssize_t>(tensor->itemsize()));
    py::buffer_info info(reinterpret_cast<void *>(const_cast<uint8_t *>(tensor->data())),
                         static_cast<ssize_t>(tensor->itemsize()), GetPyTypeFormat(tensor->data_type()),
                         static_cast<ssize_t>(tensor_shape.size()), shape, strides);
    if (!copy) {
      py::object self = py::cast(tensor);
      return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
    } else {
      return py::array(py::dtype(info), info.shape, info.strides, info.ptr);
    }
  }
}

py::tuple PyTensor::AsNumpyTuple(const InstanceData &instance_data) {
  py::tuple numpy_inputs_tuple(instance_data.size());
  for (size_t i = 0; i < instance_data.size(); i++) {  // inputs
    numpy_inputs_tuple[i] = PyTensor::AsPythonData(instance_data[i], false);
  }
  return numpy_inputs_tuple;
}

InstanceData PyTensor::AsInstanceData(const py::tuple &tuple) {
  InstanceData instance_data;
  for (auto &item : tuple) {
    TensorBasePtr tensor = nullptr;
    if (py::isinstance<py::bytes>(item)) {
      // bytes can be seen as str, so check bytes first
      tensor = std::make_shared<Tensor>();
      tensor->set_data_type(serving::kMSI_Bytes);
      auto val = std::string(item.cast<py::bytes>());
      tensor->add_bytes_data(reinterpret_cast<const uint8_t *>(val.data()), val.length());
    } else if (py::isinstance<py::str>(item)) {
      tensor = std::make_shared<Tensor>();
      tensor->set_data_type(serving::kMSI_String);
      auto val = item.cast<std::string>();
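      // Store the str's UTF-8 bytes as the tensor payload; the kMSI_String data type is
      // what AsPythonData later uses to decode the payload back into a Python str.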
      tensor->add_bytes_data(reinterpret_cast<const uint8_t *>(val.data()), val.length());
    } else if (py::isinstance<py::bool_>(item)) {
      auto val = item.cast<bool>();
      tensor = std::make_shared<Tensor>(serving::kMSI_Bool, std::vector<int64_t>(), &val, sizeof(val));
    } else if (py::isinstance<py::int_>(item)) {
      auto val = item.cast<int64_t>();
      tensor = std::make_shared<Tensor>(serving::kMSI_Int64, std::vector<int64_t>(), &val, sizeof(val));
    } else if (py::isinstance<py::float_>(item)) {
      auto val = item.cast<double>();
      tensor = std::make_shared<Tensor>(serving::kMSI_Float64, std::vector<int64_t>(), &val, sizeof(val));
    } else {
      try {
        tensor = PyTensor::MakeTensorNoCopy(py::cast<py::array>(item));
      } catch (const std::runtime_error &error) {
        MSI_LOG_EXCEPTION << "Get illegal result data with type " << py::str(item.get_type()).cast<std::string>();
      }
    }
    instance_data.push_back(tensor);
  }
  return instance_data;
}
}  // namespace mindspore::serving

================================================
FILE: mindspore_serving/ccsrc/python/tensor_py.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_SERVING_PY_H
#define MINDSPORE_SERVING_SERVING_PY_H

#include <memory>
#include <string>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "common/serving_common.h"
#include "common/instance.h"

namespace py = pybind11;

namespace mindspore::serving {
class NumpyTensor : public TensorBase {
 public:
  explicit NumpyTensor(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {}
  ~NumpyTensor() noexcept {
    py::gil_scoped_acquire acquire;
    { buffer_ = py::buffer_info(); }
  }
  /// py::array object.
  py::array py_array() const {
    // Use dummy owner to avoid copy data.
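    // Passing an owner object (the dummy py::str) makes py::array treat the buffer as
    // externally owned, so numpy wraps buffer_.ptr in place instead of copying it; the
    // actual lifetime is guaranteed by the NumpyTensor that holds buffer_.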
    py::str dummyOwner;
    return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner);
  }
  void set_data_type(DataType) override { MSI_LOG_EXCEPTION << "NumpyTensor is readonly, cannot invoke set_data_type"; }
  DataType data_type() const override { return GetDataType(buffer_); }
  void set_shape(const std::vector<int64_t> &) override {
    MSI_LOG_EXCEPTION << "NumpyTensor is readonly, cannot invoke set_shape";
  }
  std::vector<int64_t> shape() const override { return buffer_.shape; }
  const uint8_t *data() const override { return static_cast<const uint8_t *>(buffer_.ptr); }
  size_t data_size() const override {
    if (buffer_.size <= 0 || buffer_.itemsize <= 0) {
      return 0;
    }
    return static_cast<size_t>(buffer_.size * buffer_.itemsize);
  }
  bool resize_data(size_t) override { MSI_LOG_EXCEPTION << "NumpyTensor is readonly, cannot invoke resize_data"; }
  uint8_t *mutable_data() override { MSI_LOG_EXCEPTION << "NumpyTensor is readonly, cannot invoke mutable_data"; }
  void clear_bytes_data() override { MSI_LOG_EXCEPTION << "NumpyTensor is readonly, cannot invoke clear_bytes_data"; }
  void add_bytes_data(const uint8_t *, size_t) override {
    MSI_LOG_EXCEPTION << "NumpyTensor is readonly, cannot invoke add_bytes_data";
  }
  size_t bytes_data_size() const override { return 0; }
  void get_bytes_data(size_t, const uint8_t **, size_t *) const override {
    MSI_LOG_EXCEPTION << "NumpyTensor is readonly, cannot invoke get_bytes_data";
  }
  static DataType GetDataType(const py::buffer_info &buf);

 private:
  py::buffer_info buffer_;
};

class PyTensor {
 public:
  // For all types; for the BYTES type, there can only be one item in bytes_val.
  // If the tensor data is destroyed when the numpy array is returned to the python env,
  // the tensor data needs to be copied.
  static py::object AsPythonData(const TensorBasePtr &tensor, bool copy = false);
  static TensorBasePtr MakeTensor(const py::array &input);
  static TensorBasePtr MakeTensorNoCopy(const py::array &input);
  static py::tuple AsNumpyTuple(const InstanceData &instance);
  static InstanceData AsInstanceData(const py::tuple &tuple);
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_SERVING_PY_H

================================================
FILE: mindspore_serving/ccsrc/python/worker/servable_py.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include "python/worker/servable_py.h" #include #include #include #include "worker/servable_register.h" #include "worker/worker.h" namespace mindspore::serving { void PyServableRegister::RegisterMethod(const MethodSignature &method) { auto status = ServableRegister::Instance().RegisterMethod(method); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } void PyServableRegister::DeclareModel(const ModelMeta &servable) { auto status = ServableRegister::Instance().DeclareModel(servable); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } void PyServableRegister::DeclareDistributedModel(const ModelMeta &servable) { auto status = ServableRegister::Instance().DeclareDistributedModel(servable); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } void PyServableRegister::RegisterInputOutputInfo(const std::string &model_key, size_t inputs_count, size_t outputs_count, uint64_t subgraph) { auto status = ServableRegister::Instance().RegisterInputOutputInfo(model_key, inputs_count, outputs_count, subgraph); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } py::tuple PyServableRegister::Run(const std::string &model_key, const py::tuple &args, uint64_t subgraph) { std::stringstream model_stream; if (subgraph == 0) { model_stream << "Model(" << model_key << ").call()"; } else { model_stream << "Model(" << model_key << ", subgraph=" << subgraph << ").call()"; } const std::string model_str = model_stream.str(); RequestSpec request; auto const &signature = ServableRegister::Instance().GetServableSignature(); auto model_meta = signature.GetModelDeclare(model_key); if (model_meta == nullptr) { MSI_LOG_EXCEPTION << model_str << " failed: the model is not declared, ensure that interface 'declare_model' can take effect " "when importing servable_config.py by the serving server"; } auto &common_meta = model_meta->common_meta; auto input_it = common_meta.inputs_count.find(subgraph); if (input_it == common_meta.inputs_count.end()) { MSI_LOG_EXCEPTION << model_str << " failed: The model does not have subgraph of index " << subgraph << ", the subgraph count of the model is " << common_meta.inputs_count.size(); } auto input_count = input_it->second; request.servable_name = ServableRegister::Instance().GetServableSignature().servable_name; request.method_name = ServableRegister::Instance().GetCallModelMethodName(model_key, subgraph); std::vector inputs; auto inputs_args = py::cast(args); for (size_t i = 0; i < inputs_args.size(); i++) { auto input = PyTensor::AsInstanceData(py::cast(inputs_args[i])); if (input.size() != input_count) { MSI_LOG_EXCEPTION << model_str << " failed: The inputs count " << input.size() << " of instance " << i << " is not equal to the inputs count " << input_count << " of the model"; } inputs.push_back(input); } std::vector outs; { py::gil_scoped_release release; auto status = Worker::GetInstance().Run(request, inputs, &outs); if (status != SUCCESS || outs.size() == 0) { MSI_LOG_EXCEPTION << model_str << " failed: " << status.StatusMessage(); } } py::tuple outputs(outs.size()); for (size_t i = 0; i < outs.size(); i++) { auto &out = outs[i]; if (out->error_msg != SUCCESS) { MSI_LOG_EXCEPTION << model_str << " failed: " << out->error_msg.StatusMessage(); } outputs[i] = PyTensor::AsNumpyTuple(out->data); } return outputs; } } // namespace mindspore::serving ================================================ FILE: 
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_SERVABLE_PY_H
#define MINDSPORE_SERVING_WORKER_SERVABLE_PY_H

#include <string>
#include "common/servable.h"
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "pybind11/stl.h"
#include "python/tensor_py.h"

namespace py = pybind11;

namespace mindspore::serving {
class MS_API PyServableRegister {
 public:
  static void RegisterMethod(const MethodSignature &method);
  static void DeclareModel(const ModelMeta &servable);
  static void DeclareDistributedModel(const ModelMeta &servable);
  static void RegisterInputOutputInfo(const std::string &model_key, size_t inputs_count, size_t outputs_count,
                                      uint64_t subgraph = 0);
  // input args: list, output: tuple
  static py::tuple Run(const std::string &model_key, const py::tuple &args, uint64_t subgraph);
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_WORKER_SERVABLE_PY_H

================================================
FILE: mindspore_serving/ccsrc/python/worker/worker_py.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include "python/worker/worker_py.h" #include #include #include #include #include "common/exit_handle.h" #include "worker/notfiy_master/grpc_notify.h" #include "worker/local_servable/local_model_loader.h" #include "worker/distributed_worker/distributed_model_loader.h" #include "worker/inference/inference.h" #include "worker/servable_register.h" #include "worker/extra_worker/remote_call_model.h" #include "worker/context.h" namespace mindspore::serving { void PyWorker::StartServable(const std::string &servable_directory, const std::string &servable_name, uint32_t version_number, const std::string &master_address, const std::string &worker_address, const std::string &dec_key, const std::string &dec_mode) { if (Worker::GetInstance().IsRunning()) { MSI_LOG_EXCEPTION << "A servable has been started, only one servable can run in a process currently."; } Worker::GetInstance().StartListeningParentExitThread(); const auto &signature = ServableRegister::Instance().GetServableSignature(); if (signature.servable_name != servable_name) { MSI_LOG_EXCEPTION << "Servable '" << servable_name << "' has not been registered"; } Status status; std::map> models_loader; Worker::GetInstance().SetContinueListenChildren(true); status = LoadLocalModels(servable_directory, servable_name, version_number, dec_key, dec_mode, signature, &models_loader); Worker::GetInstance().SetContinueListenChildren(false); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } status = Worker::GetInstance().StartServable(servable_directory, servable_name, version_number, models_loader, master_address, worker_address, true); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } } Status PyWorker::LoadLocalModels(const std::string &servable_directory, const std::string &servable_name, uint32_t version_number, const std::string &dec_key, const std::string &dec_mode, const ServableSignature &signature, std::map> *models_loader) { Status status; for (auto &model_meta : signature.model_metas) { auto &model_key = model_meta.common_meta.model_key; auto local_models_loader = std::make_shared(); status = local_models_loader->LoadModel(servable_directory, servable_name, version_number, model_meta, dec_key, dec_mode); if (status != SUCCESS) { local_models_loader->Clear(); return status; } status = local_models_loader->AfterLoadModel(); if (status != SUCCESS) { local_models_loader->Clear(); return status; } (void)models_loader->emplace(model_key, local_models_loader); } return SUCCESS; } void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, const std::string &distributed_address, const std::string &master_address, const std::string &worker_address, uint32_t wait_agents_time_in_seconds) { if (Worker::GetInstance().IsRunning()) { MSI_LOG_EXCEPTION << "A servable has been started, only one servable can run in a process currently."; } Worker::GetInstance().StartListeningParentExitThread(); Status status; auto model_loader = std::make_shared(); status = Worker::GetInstance().StartDistributedGrpcServer(model_loader, distributed_address); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } status = model_loader->LoadModel(servable_name, rank_table_json_file, wait_agents_time_in_seconds); if (status != SUCCESS) { model_loader->Clear(); MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } status = 
  status = model_loader->AfterLoadModel();
  if (status != SUCCESS) {
    model_loader->Clear();
    MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
  }
  std::map<std::string, std::shared_ptr<ModelLoaderBase>> models_loader;
  models_loader[model_loader->GetModelKey()] = model_loader;
  status = Worker::GetInstance().StartServable(servable_directory, servable_name, version_number, models_loader,
                                               master_address, worker_address, true);
  if (status != SUCCESS) {
    MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
  }
}

void PyWorker::StartExtraServable(const std::string &servable_directory, const std::string &servable_name,
                                  uint32_t version_number, bool device_ids_empty, const std::string &dec_key,
                                  const std::string &dec_mode, const std::string &master_address,
                                  const std::string &worker_address) {
  if (Worker::GetInstance().IsRunning()) {
    MSI_LOG_EXCEPTION << "A servable has been started, only one servable can run in a process currently.";
  }
  const auto &signature = ServableRegister::Instance().GetServableSignature();
  if (signature.servable_name != servable_name) {
    MSI_LOG_EXCEPTION << "Servable '" << servable_name << "' has not been registered";
  }
  Worker::GetInstance().StartListeningParentExitThread();
  auto own_device = false;
  std::map<std::string, std::shared_ptr<ModelLoaderBase>> model_loaders;
  Status status;
  if (!signature.model_metas.empty()) {
    // if device_type is None, device_ids is empty, and there are models declared, the CPU target should be supported
    auto target_device_type = ServableContext::Instance()->GetDeviceType();
    if (target_device_type == kDeviceTypeNotSpecified && device_ids_empty) {
      auto support_device_type = InferenceLoader::Instance().GetSupportDeviceType(kDeviceTypeCpu, kUnknownType);
      if (support_device_type == kDeviceTypeNotSpecified) {
        MSI_LOG_EXCEPTION << "Servable '" << servable_name
                          << "' has models declared by declare_model, but parameter 'device_ids'"
                          << " of ServableStartConfig is not set in the Serving startup script and the MindSpore or"
                          << " Lite inference package does not support CPU";
      }
      target_device_type = kDeviceTypeCpu;
      ServableContext::Instance()->SetDeviceType(target_device_type);
    }
    if (target_device_type == kDeviceTypeCpu) {
      own_device = true;
      status = LoadLocalModels(servable_directory, servable_name, version_number, dec_key, dec_mode, signature,
                               &model_loaders);
      if (status != SUCCESS) {
        MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
      }
    } else {
      status = RemoteCallModel::InitRemote(servable_name, version_number, master_address, &model_loaders);
      if (status != SUCCESS) {
        MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
      }
    }
  }
  status = Worker::GetInstance().StartServable(servable_directory, servable_name, version_number, model_loaders,
                                               master_address, worker_address, own_device);
  if (status != SUCCESS) {
    MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
  }
}

std::vector<std::string> PyWorker::GetDeclaredModelNames() {
  std::vector<std::string> model_names;
  for (auto &model_meta : ServableRegister::Instance().GetServableSignature().model_metas) {
    // cppcheck-suppress useStlAlgorithm
    model_names.push_back(model_meta.common_meta.model_key);
  }
  return model_names;
}

bool PyWorker::EnablePyTaskQueue() { return Worker::GetInstance().GetWorkExecutor().GetPyTaskQueue().IsRunning(); }

TaskItem PyWorker::GetPyTask() {
  TaskItem item;
  Worker::GetInstance().GetWorkExecutor().GetPyTaskQueue().PyPopTask(&item);
  return item;
}

void PyWorker::PushPyTaskResult(const py::tuple &instance_outputs) {
  MSI_TIME_STAMP_START(PushPyTaskResult)
  std::vector<ResultInstance> outputs;
  ResultInstance instance;
  instance.data = PyTensor::AsInstanceData(instance_outputs);
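  // Each call pushes the outputs of exactly one instance back to the py-task queue.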
  outputs.push_back(instance);
  Worker::GetInstance().GetWorkExecutor().GetPyTaskQueue().PyPushTaskResult(outputs);
  MSI_TIME_STAMP_END(PushPyTaskResult)
}

void PyWorker::PushPyTaskFailed(int count, const std::string &error_msg) {
  auto &task_que = Worker::GetInstance().GetWorkExecutor().GetPyTaskQueue();
  auto task_info = task_que.GetHandledTaskInfo();
  auto status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR)
                << "Call " << task_info.tag << " Failed, method: '" << task_info.group_name
                << "', stage index(begin with 1): " << task_info.priority << ", error msg: " << error_msg;
  std::vector<ResultInstance> results;
  for (int i = 0; i < count; i++) {
    ResultInstance result_instance;
    result_instance.error_msg = status;
    results.push_back(result_instance);
  }
  task_que.PyPushTaskResult(results);
}

void PyWorker::PushPyTaskSystemFailed(const std::string &error_msg) {
  auto task_info = Worker::GetInstance().GetWorkExecutor().GetPyTaskQueue().GetHandledTaskInfo();
  auto status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR)
                << "Call " << task_info.tag << " Failed, method: '" << task_info.group_name
                << "', stage index(begin with 1): " << task_info.priority << ", error msg: " << error_msg;
  Worker::GetInstance().ClearOnSystemFailed(status);
}

void PyWorker::WaitAndClear() {
  {
    py::gil_scoped_release release;
    ExitSignalHandle::Instance().WorkerWait();
  }
  Worker::GetInstance().Clear();
}

void PyWorker::StopAndClear() {
  ExitSignalHandle::Instance().Stop();
  Worker::GetInstance().Clear();
}

std::string PyWorker::GetDeviceType(const std::string &target_device_type, bool enable_lite) {
  DeviceType target = kDeviceTypeNotSpecified;
  if (target_device_type == "cpu") {
    target = kDeviceTypeCpu;
  } else if (target_device_type == "gpu") {
    target = kDeviceTypeGpu;
  } else if (target_device_type == "ascend") {
    target = kDeviceTypeAscend;
  }
  ServableContext::Instance()->SetEnableLite(enable_lite);
  auto device_type = InferenceLoader::Instance().GetSupportDeviceType(target, kUnknownType);
  if (device_type == kDeviceTypeAscend) {
    return "Ascend";
  }
  if (device_type == kDeviceTypeGpu) {
    return "Gpu";
  }
  if (device_type == kDeviceTypeCpu) {
    return "Cpu";
  }
  return "";
}

bool PyWorker::SupportReuseDevice() { return InferenceLoader::Instance().SupportReuseDevice(); }

void PyWorker::NotifyFailed(const std::string &master_address, const std::string &error_msg) {
  GrpcNotifyMaster::NotifyFailed(master_address, error_msg);
}
}  // namespace mindspore::serving

================================================
FILE: mindspore_serving/ccsrc/python/worker/worker_py.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_PY_H
#define MINDSPORE_SERVING_WORKER_PY_H

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/serving_common.h"
#include "worker/worker.h"
#include "worker/task_queue.h"
#include "python/tensor_py.h"

namespace mindspore::serving {
class MS_API PyWorker {
 public:
  static void StartServable(const std::string &model_directory, const std::string &model_name,
                            uint32_t version_number, const std::string &master_address,
                            const std::string &worker_address, const std::string &dec_key,
                            const std::string &dec_mode);
  static void StartDistributedServable(const std::string &servable_directory, const std::string &servable_name,
                                       const std::string &rank_table_json_file, uint32_t version_number,
                                       const std::string &distributed_address, const std::string &master_address,
                                       const std::string &worker_address, uint32_t wait_agents_time_in_seconds);
  static void StartExtraServable(const std::string &model_directory, const std::string &model_name,
                                 uint32_t version_number, bool device_ids_empty, const std::string &dec_key,
                                 const std::string &dec_mode, const std::string &master_address,
                                 const std::string &worker_address);
  static std::vector<std::string> GetDeclaredModelNames();
  static void WaitAndClear();
  static void StopAndClear();
  static bool EnablePyTaskQueue();
  static TaskItem GetPyTask();
  static void PushPyTaskResult(const py::tuple &instance_outputs);
  static void PushPyTaskFailed(int count, const std::string &error_msg);
  static void PushPyTaskSystemFailed(const std::string &error_msg);
  static std::string GetDeviceType(const std::string &target_device_type, bool enable_lite);
  static bool SupportReuseDevice();
  // for grpc notify failed of worker
  static void NotifyFailed(const std::string &master_address, const std::string &error_msg);

 private:
  static Status LoadLocalModels(const std::string &servable_directory, const std::string &servable_name,
                                uint32_t version_number, const std::string &dec_key, const std::string &dec_mode,
                                const ServableSignature &signature,
                                std::map<std::string, std::shared_ptr<ModelLoaderBase>> *models_loader);
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_WORKER_PY_H

================================================
FILE: mindspore_serving/ccsrc/worker/context.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include "worker/context.h" namespace mindspore::serving { std::shared_ptr ServableContext::Instance() { static std::shared_ptr instance = nullptr; if (instance == nullptr) { instance = std::make_shared(); } return instance; } void ServableContext::SetDeviceType(DeviceType device_type) { device_type_ = device_type; } DeviceType ServableContext::GetDeviceType() const { return device_type_; } void ServableContext::SetDeviceId(uint32_t device_id) { device_id_ = device_id; } uint32_t ServableContext::GetDeviceId() const { return device_id_; } Status ServableContext::SetDeviceTypeStr(const std::string &device_type) { DeviceType type; std::string device_type_lowcase = device_type; for (auto &c : device_type_lowcase) { // cppcheck-suppress useStlAlgorithm if (c >= 'A' && c <= 'Z') { c = c - 'A' + 'a'; } } if (device_type_lowcase == "ascend" || device_type_lowcase == "davinci") { type = kDeviceTypeAscend; } else if (device_type_lowcase == "gpu") { type = kDeviceTypeGpu; } else if (device_type_lowcase == "cpu") { type = kDeviceTypeCpu; } else if (device_type_lowcase == "none") { type = kDeviceTypeNotSpecified; } else { return INFER_STATUS_LOG_ERROR(FAILED) << "Unsupported device type '" << device_type << "', only support 'Ascend', 'GPU', 'CPU' and None, case ignored"; } SetDeviceType(type); return SUCCESS; } void ServableContext::SetEnableLite(bool enable_lite) { enable_lite_ = enable_lite; } bool ServableContext::EnableLite() const { return enable_lite_; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/worker/context.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_WORKER_CONTEXT_H #define MINDSPORE_SERVING_WORKER_CONTEXT_H #include #include #include #include "common/serving_common.h" #include "worker/inference/inference.h" namespace mindspore::serving { class MS_API ServableContext { public: static std::shared_ptr Instance(); Status SetDeviceTypeStr(const std::string &device_type); void SetDeviceType(DeviceType device_type); DeviceType GetDeviceType() const; void SetDeviceId(uint32_t device_id); uint32_t GetDeviceId() const; void SetEnableLite(bool enable_lite); bool EnableLite() const; private: DeviceType device_type_ = kDeviceTypeNotSpecified; uint32_t device_id_ = 0; bool enable_lite_ = false; }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_WORKER_CONTEXT_H ================================================ FILE: mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "worker/distributed_worker/agent_process/agent_process.h"
#include "worker/distributed_worker/worker_agent.h"

namespace mindspore {
namespace serving {
grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request,
                               proto::DistributedExitReply *reply) {
  MSI_LOG(INFO) << "Distributed Worker Exit";
  WorkerAgent::Instance().StopAgent(false);
  return grpc::Status::OK;
}

grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request,
                                  proto::DistributedPredictReply *reply) {
  MSI_LOG(INFO) << "Begin call service Eval";
  WorkerAgent::Instance().Run(*request, reply);
  MSI_LOG(INFO) << "End call service Eval";
  return grpc::Status::OK;
}

grpc::Status MSAgentImpl::Ping(grpc::ServerContext *context, const proto::PingRequest *request,
                               proto::PingReply *reply) {
  MSI_EXCEPTION_IF_NULL(request);
  MSI_EXCEPTION_IF_NULL(reply);
  watcher_->RecvPing(request->address());
  return grpc::Status::OK;
}

grpc::Status MSAgentImpl::Pong(grpc::ServerContext *context, const proto::PongRequest *request,
                               proto::PongReply *reply) {
  MSI_EXCEPTION_IF_NULL(request);
  MSI_EXCEPTION_IF_NULL(reply);
  watcher_->RecvPong(request->address());
  return grpc::Status::OK;
}
}  // namespace serving
}  // namespace mindspore

================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H
#define MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "common/serving_common.h"
#include "common/heart_beat.h"
#include "proto/ms_agent.pb.h"
#include "proto/ms_agent.grpc.pb.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"

namespace mindspore {
namespace serving {
// Service Implement
class MSAgentImpl final : public proto::MSAgent::Service {
 public:
  explicit MSAgentImpl(const std::string server_address) {
    if (!watcher_) {
      watcher_ = std::make_shared<Watcher<proto::MSWorker>>(server_address);
    }
  }
  grpc::Status Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request,
                       proto::DistributedPredictReply *reply) override;
  grpc::Status Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request,
                    proto::DistributedExitReply *reply) override;
  grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override;
  grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override;

 private:
  std::shared_ptr<Watcher<proto::MSWorker>> watcher_;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H

================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include "worker/distributed_worker/agent_startup.h" #include #include "worker/distributed_worker/notify_distributed/notify_worker.h" #include "common/grpc_server.h" namespace mindspore { namespace serving { WorkerAgentStartUp &WorkerAgentStartUp::Instance() { static WorkerAgentStartUp instance; return instance; } Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &distributed_address) { return GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker(distributed_address, &config_); } Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) { MSI_EXCEPTION_IF_NULL(config); if (config_.rank_list.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Rank table config is not ready"; } *config = config_; return SUCCESS; } Status WorkerAgentStartUp::NotifyFailed(const std::string &distributed_address) { return GrpcNotifyDistributeWorker::NotifyFailed(distributed_address); } void WorkerAgentStartUp::StartupNotifyExit(const std::string &distributed_address, const std::string &agent_ip) { GrpcNotifyDistributeWorker::StartupNotifyExit(distributed_address, agent_ip); } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H #define MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H #include #include #include "common/serving_common.h" #include "worker/distributed_worker/common.h" #include "worker/inference/inference.h" namespace mindspore { namespace serving { class MS_API WorkerAgentStartUp { public: static WorkerAgentStartUp &Instance(); // from python, worker_agent.py // start_worker_agent // step1, get agents config from worker Status GetAgentsConfigsFromWorker(const std::string &distributed_address); // step2, invoke from python Status GetDistributedServableConfig(DistributedServableConfig *config); Status NotifyFailed(const std::string &distributed_address); void StartupNotifyExit(const std::string &distributed_address, const std::string &agent_ip); private: DistributedServableConfig config_; std::string worker_address_; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H ================================================ FILE: mindspore_serving/ccsrc/worker/distributed_worker/common.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H
#define MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H

#include <map>
#include <string>
#include <vector>
#include "common/serving_common.h"
#include "worker/inference/inference.h"
#include "common/servable.h"

namespace mindspore {
namespace serving {
struct OneRankConfig {
  std::string ip;
  uint32_t device_id = 0;
};

struct DistributedServableConfig {
  std::string rank_table_content;
  std::vector<OneRankConfig> rank_list;
  CommonModelMeta common_meta;
  DistributedModelMeta distributed_meta;
};

struct AgentStartUpConfig {
  uint32_t rank_id;
  uint32_t device_id;
  std::vector<std::string> model_file_names;
  std::vector<std::string> group_file_names;
  std::string rank_table_json_file_name;
  std::string agent_address;
  std::string distributed_address;
  uint32_t worker_port;
  CommonModelMeta common_meta;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H

================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/distributed_model_loader.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "worker/distributed_worker/distributed_model_loader.h"
#include <future>
#include <memory>
#include <vector>
#include "worker/distributed_worker/notify_agent/notify_agent.h"
#include "worker/worker.h"
#include "common/exit_handle.h"
#include "common/proto_tensor.h"
#include "worker/servable_register.h"

namespace mindspore {
namespace serving {
struct DistributedPredictMsg {
  proto::DistributedPredictReply reply;
  std::promise<void> promise = std::promise<void>();
  Status status = FAILED;
  std::future<void> future = promise.get_future();
};

DistributedModelLoader::~DistributedModelLoader() { Clear(); }

uint64_t DistributedModelLoader::GetGraphNum() const { return graph_num_; }

Status DistributedModelLoader::Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output,
                                       uint64_t subgraph) {
  Status status(SUCCESS);
  if (config_.distributed_meta.enable_pipeline_infer) {
    std::shared_lock lock{rw_mutex_};
    status = PredictInner(input, output, subgraph);
  } else {
    std::unique_lock lock{rw_mutex_};
    status = PredictInner(input, output, subgraph);
  }
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Predict error happened, now exit distributed servable";
    Worker::GetInstance().StopServable();
  }
  return status;
}

Status DistributedModelLoader::PredictInner(const std::vector<TensorBasePtr> &input,
                                            std::vector<TensorBasePtr> *output, uint64_t subgraph) {
  MSI_EXCEPTION_IF_NULL(output);
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  proto::DistributedPredictRequest request;
  proto::DistributedPredictRequest empty_request;
  request.set_subgraph(subgraph);
  for (const auto &tensor_ptr : input) {
    auto tensor = request.add_inputs();
    ProtoTensor proto_tensor(tensor);
    proto_tensor.assign(*tensor_ptr);
  }
  auto rank_size = config_.distributed_meta.rank_size;
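  // The request is fanned out to every rank below: ranks in the first pipeline stage
  // (and all ranks when all_stage_has_input_ is set) receive the real inputs, the other
  // ranks receive an empty request, and only the last rank is asked to return results.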
================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/distributed_model_loader.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "worker/distributed_worker/distributed_model_loader.h"
// NOTE: system include targets and template arguments were lost in extraction; restored from usage.
#include <chrono>
#include <climits>
#include <cstdlib>
#include <fstream>
#include <set>
#include <sstream>
#include "worker/distributed_worker/notify_agent/notify_agent.h"
#include "worker/worker.h"
#include "common/exit_handle.h"
#include "common/proto_tensor.h"
#include "worker/servable_register.h"

namespace mindspore {
namespace serving {

struct DistributedPredictMsg {
  proto::DistributedPredictReply reply;
  std::promise<void> promise = std::promise<void>();
  Status status = FAILED;
  std::future<void> future = promise.get_future();
};

DistributedModelLoader::~DistributedModelLoader() { Clear(); }

uint64_t DistributedModelLoader::GetGraphNum() const { return graph_num_; }

Status DistributedModelLoader::Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output,
                                       uint64_t subgraph) {
  Status status(SUCCESS);
  if (config_.distributed_meta.enable_pipeline_infer) {
    std::shared_lock lock{rw_mutex_};
    status = PredictInner(input, output, subgraph);
  } else {
    std::unique_lock lock{rw_mutex_};
    status = PredictInner(input, output, subgraph);
  }
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Predict error happened, now exit distributed servable";
    Worker::GetInstance().StopServable();
  }
  return status;
}

Status DistributedModelLoader::PredictInner(const std::vector<TensorBasePtr> &input,
                                            std::vector<TensorBasePtr> *output, uint64_t subgraph) {
  MSI_EXCEPTION_IF_NULL(output);
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  proto::DistributedPredictRequest request;
  proto::DistributedPredictRequest empty_request;
  // set the subgraph index on both the real and the empty request
  request.set_subgraph(subgraph);
  empty_request.set_subgraph(subgraph);
  for (const auto &tensor_ptr : input) {
    auto tensor = request.add_inputs();
    ProtoTensor proto_tensor(tensor);
    proto_tensor.assign(*tensor_ptr);
  }
  auto rank_size = config_.distributed_meta.rank_size;
  auto stage_size = config_.distributed_meta.stage_size;
  if (rank_size != agent_spec_map_.size()) {
    MSI_LOG_EXCEPTION << "agent_spec_map_ size " << agent_spec_map_.size() << " not match rank size " << rank_size;
  }
  auto agent_num_per_stage = rank_size / stage_size;
  auto result_agent_id = rank_size - 1;
  auto msg_list = std::make_shared<std::vector<DistributedPredictMsg>>(rank_size);
  request.set_return_result(false);
  empty_request.set_return_result(false);
  std::unique_lock wait_lock(wait_mutex_);
  for (size_t i = 0; i < rank_size; ++i) {
    AsyncPredictCallback callback = [msg_list, i](const Status &status) {
      msg_list->at(i).status = status;
      msg_list->at(i).promise.set_value();
    };
    // stage-0 ranks receive the real inputs; later stages get an empty request unless all stages have inputs
    if (i < agent_num_per_stage || all_stage_has_input_) {
      if (i == result_agent_id) {
        request.set_return_result(true);
      }
      agent_spec_map_[i].notify_agent_->DispatchAsync(request, &msg_list->at(i).reply, callback);
    } else {
      if (i == result_agent_id) {
        empty_request.set_return_result(true);
      }
      agent_spec_map_[i].notify_agent_->DispatchAsync(empty_request, &msg_list->at(i).reply, callback);
    }
  }
  wait_lock.unlock();
  for (size_t rank_id = 0; rank_id < msg_list->size(); ++rank_id) {
    auto &predict_msg = msg_list->at(rank_id);
    auto &future = predict_msg.future;
    const uint64_t kWaitMaxHundredMs = 10 * 10;  // waiting for 10s
    uint64_t k;
    for (k = 0; k < kWaitMaxHundredMs; k++) {
      if (ExitSignalHandle::Instance().HasStopped()) {
        return INFER_STATUS_LOG_ERROR(FAILED) << "Worker has stopped";
      }
      // waiting for 100ms
      if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) {
        break;
      }
    }
    if (k >= kWaitMaxHundredMs) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to wait for result of rank " << rank_id;
    }
    auto status = predict_msg.status;
    if (status != SUCCESS) {
      return INFER_STATUS_LOG_ERROR(FAILED)
             << "Error happened on get result of rank " << rank_id << ": " << status.StatusMessage();
    }
    auto &reply = predict_msg.reply;
    if (reply.has_error_msg() && reply.error_msg().error_code() != 0) {
      return INFER_STATUS_LOG_ERROR(FAILED)
             << "Error happened on get result of rank " << rank_id << ": " << reply.error_msg().error_msg();
    }
  }
  auto &reply = msg_list->at(result_agent_id).reply;
  for (int i = 0; i < reply.outputs_size(); ++i) {
    auto p = std::make_shared<ProtoTensor>(reply.mutable_outputs(i));
    auto tensor_ptr = std::make_shared<Tensor>(p->data_type(), p->shape(), p->data(), p->data_size());
    output->push_back(tensor_ptr);
  }
  return SUCCESS;
}
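PredictInner's result-collection loop is a reusable pattern: poll each future in 100 ms slices so a process-wide stop signal can interrupt an otherwise 10-second wait. A minimal standalone sketch of just that pattern (WaitBounded is a hypothetical helper, not serving code):

#include <chrono>
#include <cstdint>
#include <functional>
#include <future>

template <typename T>
bool WaitBounded(std::future<T> &future, const std::function<bool()> &has_stopped) {
  constexpr uint64_t kWaitMaxHundredMs = 10 * 10;  // 10 s budget, in 100 ms slices
  for (uint64_t k = 0; k < kWaitMaxHundredMs; k++) {
    if (has_stopped()) {
      return false;  // interrupted by an external stop signal
    }
    if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) {
      return true;  // result arrived within the budget
    }
  }
  return false;  // timed out
}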
std::vector<TensorInfo> DistributedModelLoader::GetInputInfos(uint64_t subgraph) const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  auto iter = input_infos_.find(subgraph);
  if (iter != input_infos_.end()) {
    return iter->second;
  }
  MSI_LOG_EXCEPTION << "subgraph: " << subgraph << " does not exist";
  return {};
}

std::vector<TensorInfo> DistributedModelLoader::GetOutputInfos(uint64_t subgraph) const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  auto iter = output_infos_.find(subgraph);
  if (iter != output_infos_.end()) {
    return iter->second;
  }
  MSI_LOG_EXCEPTION << "subgraph: " << subgraph << " does not exist";
  return {};
}

uint64_t DistributedModelLoader::GetBatchSize() const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return batch_size_;
}

Status DistributedModelLoader::GetDistributedServableConfig(DistributedServableConfig *config) const {
  if (!config_loaded_) {
    return INFER_STATUS(FAILED) << "Config not loaded";
  }
  *config = config_;
  return SUCCESS;
}

void DistributedModelLoader::SetWaitAgentsPromise(bool flag) {
  // test_and_set returns false only for the first caller, so the promise is set exactly once
  if (!promise_set_flag_.test_and_set()) {
    agents_promise_.set_value(flag);
    registered_end_flag_ = true;
  }
}

Status DistributedModelLoader::RegisterAgent(const std::vector<WorkerAgentSpec> &agent_specs) {
  std::unique_lock lock{rw_mutex_};
  if (registered_end_flag_) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Distributed servable has ended up registration";
  }
  if (agent_specs.empty()) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "The number of graph cannot be 0";
  }
  if (agent_specs[0].rank_id >= config_.distributed_meta.rank_size) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Invalid rank id " << agent_specs[0].rank_id << ", rank size " << config_.distributed_meta.rank_size;
  }
  DistributedAgentContext context;
  auto it = agent_spec_map_.find(agent_specs[0].rank_id);
  if (it != agent_spec_map_.end()) {
    MSI_LOG_WARNING << "rank_id " << agent_specs[0].rank_id << " has been registered";
    return SUCCESS;
  }
  context.agent_spec_ = agent_specs;
  std::shared_ptr<GrpcNotifyAgent> notify_agent = std::make_shared<GrpcNotifyAgent>(agent_specs[0].agent_address);
  context.notify_agent_ = notify_agent;
  agent_spec_map_[agent_specs[0].rank_id] = context;
  MSI_LOG_INFO << "Rank " << agent_specs[0].rank_id << " has been registered";
  if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) {
    SetWaitAgentsPromise(true);
  }
  return SUCCESS;
}

void DistributedModelLoader::Clear() {
  std::unique_lock lock{rw_mutex_};
  for (auto &agent : agent_spec_map_) {
    agent.second.notify_agent_->Exit();
  }
  agent_spec_map_.clear();
  model_loaded_ = false;
  MSI_LOG_INFO << "End clear distributed servable";
}

Status DistributedModelLoader::OnAgentExit() {
  std::unique_lock lock{rw_mutex_};
  MSI_LOG_INFO << "Worker agent notify exit";
  SetWaitAgentsPromise(false);
  model_loaded_ = false;
  return SUCCESS;
}

Status DistributedModelLoader::LoadModel(const std::string &servable_name, const std::string &rank_table_json_file,
                                         uint64_t wait_agents_time_in_seconds) {
  if (model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has already been loaded";
  }
  rank_table_json_file_ = rank_table_json_file;
  const ServableSignature &signature = ServableRegister::Instance().GetServableSignature();
  if (signature.servable_name != servable_name) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered";
  }
  if (signature.model_metas.size() != 1) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Distributed servable '" << servable_name
           << "' has not been declared or has been declared more than once, "
           << "declared number: " << signature.model_metas.size();
  }
  if (signature.servable_type != kServableTypeDistributed) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Servable '" << servable_name << "' is not registered as distributed servable";
  }
  auto &meta = signature.model_metas[0];
  model_key_ = meta.common_meta.model_key;
  config_.common_meta = meta.common_meta;
  config_.distributed_meta = meta.distributed_meta;
  auto status = InitConfigOnStartup(rank_table_json_file_);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Init with rank table on start up failed";
    return status;
  }
  status = CheckRankConfig();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Check rank config failed";
    return status;
  }
  config_loaded_ = true;
  status = WaitAgentsReady(wait_agents_time_in_seconds);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Waiting for ready of agents failed";
    return status;
  }
  status = CheckAgentsInfosAndInitTensorInfos();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Check agents infos failed";
    return status;
  }
  model_loaded_ = true;
  return SUCCESS;
}
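SetWaitAgentsPromise relies on std::atomic_flag::test_and_set() returning false only for the first caller, so agents_promise_ is fulfilled exactly once even when a late registration and an agent failure race each other. The idiom in isolation (OneShotSignal is a hypothetical standalone class, not serving code):

#include <atomic>
#include <future>

class OneShotSignal {
 public:
  void Set(bool ok) {
    // only the first Set() wins; later calls are silently ignored
    if (!flag_.test_and_set()) {
      promise_.set_value(ok);
    }
  }
  std::future<bool> GetFuture() { return promise_.get_future(); }

 private:
  std::atomic_flag flag_ = ATOMIC_FLAG_INIT;
  std::promise<bool> promise_;
};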
std::string RealPath(const std::string &path) {
  // Return the absolute path when the path is accessible, otherwise an empty string
  std::string res;
  char resolved_path[PATH_MAX] = {0};
  if (realpath(path.c_str(), resolved_path) != nullptr) {
    res = resolved_path;
  }
  return res;
}

Status DistributedModelLoader::InitConfigOnStartup(const std::string &rank_table_json_file) {
  std::string rank_table_json_abs_path = RealPath(rank_table_json_file);
  if (rank_table_json_abs_path.empty()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Failed to get realpath of: " << rank_table_json_file.c_str();
  }
  MSI_LOG(INFO) << "Begin to parse rank table json file: " << rank_table_json_file.c_str();
  std::ifstream json_file(rank_table_json_abs_path);
  if (!json_file.is_open()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
           << "Failed to open rank table file: " << rank_table_json_file.c_str();
  }
  std::stringstream buffer;
  buffer << json_file.rdbuf();
  config_.rank_table_content = buffer.str();
  json rank_table_json;
  try {
    rank_table_json = nlohmann::json::parse(config_.rank_table_content);
  } catch (json::parse_error &e) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Parse error: " << e.what();
  } catch (json::out_of_range &e) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Out of range: " << e.what();
  } catch (json::exception &e) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Json exception: " << e.what();
  }
  if (!rank_table_json.is_object()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << rank_table_json_file.c_str() << " is not a json object";
  }
  // rank tables come in two flavors: with a "group_list" or with a "server_list"
  if (rank_table_json.find("group_list") != rank_table_json.end()) {
    return ParserRankTableWithGroupList(rank_table_json_file, rank_table_json);
  } else {
    return ParserRankTableWithServerList(rank_table_json_file, rank_table_json);
  }
}

json DistributedModelLoader::ParserArrayInJson(const json &json_array, const std::string &str) {
  json temp_array;
  auto iter = json_array.find(str);
  if (iter == json_array.end()) {
    MSI_LOG_ERROR << "Check rank table file failed, attr '" << str << "' is not found";
    return temp_array;
  }
  if (!iter->is_array()) {
    MSI_LOG_ERROR << "Check rank table file failed, attr '" << str << "' is not an array";
    return temp_array;
  }
  temp_array = json_array.at(str);
  return temp_array;
}

std::string DistributedModelLoader::ParserStringInJson(const json &json_str, const std::string &str) {
  std::string temp_str;
  auto iter = json_str.find(str);
  if (iter == json_str.end()) {
    MSI_LOG_ERROR << "Check rank table file failed, attr '" << str << "' is not found";
    return temp_str;
  }
  if (!iter->is_string()) {
    MSI_LOG_ERROR << "Check rank table file failed, attr '" << str << "' is not a string";
    return temp_str;
  }
  json temp_json_str = json_str.at(str);
  temp_str = temp_json_str.get<std::string>();
  return temp_str;
}

Status DistributedModelLoader::ParserRankTableWithGroupList(const std::string &rank_table_json_file,
                                                            const json &rank_table_json) {
  MSI_LOG_INFO << "Begin to parse rank table with group list";
  auto server_list = ParserArrayInJson(rank_table_json, "group_list");
  if (server_list.empty()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "group_list attr is empty in " << rank_table_json_file.c_str();
  }
  size_t rank_id = 0;
  for (auto &server : server_list) {
    auto instance_list = ParserArrayInJson(server, "instance_list");
    if (instance_list.empty()) {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
             << "instance_list attr is empty in " << rank_table_json_file.c_str();
    }
    for (auto &instance : instance_list) {
      auto str_server_id = ParserStringInJson(instance, "server_id");
      if (str_server_id.empty()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
               << "server_id attr is empty in " << rank_table_json_file.c_str();
      }
      OneRankConfig one_rank_config;
      one_rank_config.ip = str_server_id;
      auto devices = ParserArrayInJson(instance, "devices");
      if (devices.empty()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "devices attr is empty in " << rank_table_json_file.c_str();
      }
      auto str_device_id = ParserStringInJson(devices.at(0), "device_id");
      if (str_device_id.empty()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
               << "device_id attr is empty in " << rank_table_json_file.c_str();
      }
      uint32_t temp_device_id;
      auto status = ConvertStr2Int(rank_table_json_file, str_device_id, "device_id", &temp_device_id);
      if (status != SUCCESS) {
        MSI_LOG_ERROR << "Convert device_id from string to int failed";
        return status;
      }
      auto str_rank_id = ParserStringInJson(instance, "rank_id");
      if (str_rank_id.empty()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "rank_id attr is empty in " << rank_table_json_file.c_str();
      }
      uint32_t temp_rank_id;
      status = ConvertStr2Int(rank_table_json_file, str_rank_id, "rank_id", &temp_rank_id);
      if (status != SUCCESS) {
        MSI_LOG_ERROR << "Convert rank_id from string to int failed";
        return status;
      }
      if (rank_id != temp_rank_id) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
               << "device size not match rank_id in " << rank_table_json_file.c_str();
      }
      rank_id++;
      one_rank_config.device_id = temp_device_id;
      config_.rank_list.push_back(one_rank_config);
    }
  }
  MSI_LOG(INFO) << "Success to parse rank table json file with group list and save to DistributedServableConfig";
  return SUCCESS;
}

Status DistributedModelLoader::ConvertStr2Int(const std::string &rank_table_json_file, const std::string &para_str,
                                              const std::string &para_key, uint32_t *para_int) const {
  uint32_t parsed_value = 0;
  constexpr uint32_t decimal_times = 10;
  for (auto c : para_str) {
    if (c < '0' || c > '9') {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
             << para_key << " attr is invalid argument in " << rank_table_json_file.c_str();
    }
    parsed_value = parsed_value * decimal_times + c - '0';
  }
  // round-trip check: rejects empty strings, leading zeros and values that overflow uint32_t
  if (std::to_string(parsed_value) != para_str) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
           << para_key << " attr is invalid argument in " << rank_table_json_file.c_str();
  }
  *para_int = parsed_value;
  return SUCCESS;
}
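ConvertStr2Int validates by re-serializing the parsed value and comparing it with the input, which cheaply rejects "007", "", "1x", and anything that wrapped around uint32_t (a wrapped value can never reproduce an 11-digit-plus source string). The same trick in isolation (StrictParseU32 is a hypothetical standalone helper):

#include <cstdint>
#include <string>

bool StrictParseU32(const std::string &s, uint32_t *out) {
  uint32_t value = 0;
  for (char c : s) {
    if (c < '0' || c > '9') {
      return false;  // non-digit character
    }
    value = value * 10 + static_cast<uint32_t>(c - '0');
  }
  if (std::to_string(value) != s) {
    return false;  // catches leading zeros, empty input and overflow
  }
  *out = value;
  return true;
}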
Status DistributedModelLoader::ParserRankTableWithServerList(const std::string &rank_table_json_file,
                                                             const json &rank_table_json) {
  MSI_LOG_INFO << "Begin to parse rank table with server list";
  auto server_list = ParserArrayInJson(rank_table_json, "server_list");
  if (server_list.empty()) {
    return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "server_list attr is empty in " << rank_table_json_file.c_str();
  }
  size_t rank_id = 0;
  for (auto &server : server_list) {
    auto server_id = ParserStringInJson(server, "server_id");
    if (server_id.empty()) {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "server_id attr is empty in " << rank_table_json_file.c_str();
    }
    auto device_list = ParserArrayInJson(server, "device");
    if (device_list.empty()) {
      return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "device attr is empty in " << rank_table_json_file.c_str();
    }
    for (auto &device : device_list) {
      OneRankConfig one_rank_config;
      one_rank_config.ip = server_id;
      auto str_device_id = ParserStringInJson(device, "device_id");
      if (str_device_id.empty()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
               << "device_id attr is empty in " << rank_table_json_file.c_str();
      }
      uint32_t temp_device_id;
      auto status = ConvertStr2Int(rank_table_json_file, str_device_id, "device_id", &temp_device_id);
      if (status != SUCCESS) {
        MSI_LOG_ERROR << "Convert device_id from string to int failed";
        return status;
      }
      auto str_rank_id = ParserStringInJson(device, "rank_id");
      if (str_rank_id.empty()) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "rank_id attr is empty in " << rank_table_json_file.c_str();
      }
      uint32_t temp_rank_id;
      status = ConvertStr2Int(rank_table_json_file, str_rank_id, "rank_id", &temp_rank_id);
      if (status != SUCCESS) {
        MSI_LOG_ERROR << "Convert rank_id from string to int failed";
        return status;
      }
      if (rank_id != temp_rank_id) {
        return INFER_STATUS_LOG_ERROR(INVALID_INPUTS)
               << "device size not match rank_id in " << rank_table_json_file.c_str();
      }
      rank_id++;
      one_rank_config.device_id = temp_device_id;
      config_.rank_list.push_back(one_rank_config);
    }
  }
  MSI_LOG(INFO) << "Success to parse rank table json file with server list and save to DistributedServableConfig";
  return SUCCESS;
}
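The field names the two parsers read suggest a rank table shaped like the following. This is an illustrative "server_list"-style table for two devices on one host, embedded in a runnable nlohmann::json snippet; real HCCL rank tables (see example/matmul_distributed/rank_table_8pcs.json) carry additional fields that the parser above simply ignores.

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
  const char *rank_table = R"({
    "server_list": [
      {
        "server_id": "192.168.1.1",
        "device": [
          {"device_id": "0", "rank_id": "0"},
          {"device_id": "1", "rank_id": "1"}
        ]
      }
    ]
  })";
  auto table = nlohmann::json::parse(rank_table);
  // walk the table exactly the way ParserRankTableWithServerList does
  for (auto &server : table["server_list"]) {
    for (auto &device : server["device"]) {
      std::cout << "rank " << device["rank_id"].get<std::string>() << " -> ip "
                << server["server_id"].get<std::string>() << ", device "
                << device["device_id"].get<std::string>() << std::endl;
    }
  }
  return 0;
}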
Status DistributedModelLoader::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) {
  MSI_LOG_INFO << "Begin waiting ready of all agents";
  auto future = agents_promise_.get_future();
  if (wait_agents_time_in_seconds == 0) {
    wait_agents_time_in_seconds = UINT32_MAX;
  }
  const uint64_t kWaitMaxHundredMs = wait_agents_time_in_seconds * 10;
  uint64_t i;
  for (i = 0; i < kWaitMaxHundredMs; i++) {
    if (ExitSignalHandle::Instance().HasStopped()) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "Worker or Agents has stopped";
    }
    // waiting for 100ms
    if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) {
      auto flag = future.get();
      if (!flag) {
        return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to start all agents, maybe some error reported";
      }
      break;
    }
  }
  if (i >= kWaitMaxHundredMs) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Failed to wait for ready of all agents, current agents count: " << agent_spec_map_.size()
           << ", rank size: " << config_.distributed_meta.rank_size;
  }
  MSI_LOG_INFO << "Success waiting ready of all agents";
  return SUCCESS;
}

Status DistributedModelLoader::CompareTensorInfos(const std::vector<TensorInfo> &lefts,
                                                  const std::vector<TensorInfo> &rights) {
  if (lefts.size() != rights.size()) {
    return INFER_STATUS(FAILED) << "Size not match, left: " << lefts.size() << ", right: " << rights.size();
  }
  auto tensor_info_as_str = [](const TensorInfo &tensor_info) {
    Status status = INFER_STATUS(SUCCESS) << "size: " << tensor_info.size << ", data type: " << tensor_info.data_type
                                          << ", shape: " << tensor_info.shape;
    return status.StatusMessage();
  };
  for (size_t k = 0; k < lefts.size(); k++) {
    auto &left = lefts[k];
    auto &right = rights[k];
    if (left.size != right.size || left.shape != right.shape || left.data_type != right.data_type) {
      return INFER_STATUS(FAILED) << "Index " << k << " tensor not match, left- " << tensor_info_as_str(left)
                                  << "; right- " << tensor_info_as_str(right);
    }
  }
  return SUCCESS;
}

Status DistributedModelLoader::CheckAgentsInfosAndInitTensorInfos() {
  auto rank_size = config_.distributed_meta.rank_size;
  auto stage_size = config_.distributed_meta.stage_size;
  auto parallel_count = rank_size / stage_size;
  MSI_LOG_INFO << "Check agents infos, rank size: " << rank_size << ", stage size: " << stage_size
               << ", parallel count(rank size/stage size): " << parallel_count;
  if (agent_spec_map_.size() != rank_size) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Registered agents size " << agent_spec_map_.size() << " not match rank size " << rank_size;
  }
  graph_num_ = agent_spec_map_[0].agent_spec_.size();
  for (size_t i = 1; i < rank_size; i++) {
    if (graph_num_ != agent_spec_map_[i].agent_spec_.size()) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "The number of graph not match in different agent";
    }
  }
  batch_size_ = agent_spec_map_[0].agent_spec_[0].batch_size;
  for (size_t subgraph = 0; subgraph < agent_spec_map_[0].agent_spec_.size(); subgraph++) {
    input_infos_[subgraph] = agent_spec_map_[0].agent_spec_[subgraph].input_infos;
    output_infos_[subgraph] = agent_spec_map_[rank_size - 1].agent_spec_[subgraph].output_infos;
    if (input_infos_[subgraph].empty()) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << 0 << " input count cannot be 0";
    }
    if (output_infos_[subgraph].empty()) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size - 1 << " output count cannot be 0";
    }
    const auto &input_infos = input_infos_[subgraph];
    Status status;
    // ranks within stage 0 must all see the inputs of rank 0
    for (size_t i = 0; i < parallel_count; i++) {
      auto &agent_spec = agent_spec_map_[i].agent_spec_[subgraph];
      status = CompareTensorInfos(agent_spec.input_infos, input_infos);
      if (status != SUCCESS) {
        status = INFER_STATUS_LOG_ERROR(FAILED)
                 << "Rank " << i << " input infos not match rank 0, subgraph: " << subgraph
                 << ", you can check if the actual stage size of the distributed model matches the stage size "
                    "declared in servable_config.py, details: "
                 << status.StatusMessage();
        return status;
      }
    }
    // ranks of later stages either all have empty inputs or all share rank 0's inputs
    for (size_t i = parallel_count; i < rank_size; i++) {
      auto &agent_spec = agent_spec_map_[i].agent_spec_[subgraph];
      if (agent_spec.input_infos.empty()) {
        if (all_stage_has_input_) {
          return INFER_STATUS_LOG_ERROR(FAILED)
                 << "Expect stage 0(other stages have empty inputs) or all stages have same inputs, detect rank "
                 << (i - 1) << " input count is " << agent_spec.input_infos.size() << ", but rank " << i
                 << " input count is 0, subgraph: " << subgraph;
        }
        continue;
      }
      status = CompareTensorInfos(agent_spec.input_infos, input_infos);
      if (status != SUCCESS) {
        return INFER_STATUS_LOG_ERROR(FAILED)
               << "Expect stage 0(other stages have empty inputs) or all stages have same inputs, detect rank "
               << i - 1 << " and rank " << i << " inputs are different, subgraph: " << subgraph
               << ", details: " << status.StatusMessage();
      }
      all_stage_has_input_ = true;
    }
    // within each stage, all ranks must report the same outputs and a consistent batch size
    for (size_t i = 0; i < rank_size; i += parallel_count) {
      const auto &first_item = agent_spec_map_[i].agent_spec_[subgraph];
      for (size_t k = 0; k < parallel_count && i + k < rank_size; k++) {
        auto rank_id = i + k;
        const auto &agent_spec = agent_spec_map_[i + k].agent_spec_[subgraph];
        status = CompareTensorInfos(agent_spec.output_infos, first_item.output_infos);
        if (status != SUCCESS) {
          status = INFER_STATUS_LOG_ERROR(FAILED)
                   << "Rank " << rank_id << " output infos not match rank " << i << ", subgraph: " << subgraph
                   << ", details: " << status.StatusMessage();
          return status;
        }
        if (agent_spec.batch_size != 0 && agent_spec.batch_size != batch_size_) {
          return INFER_STATUS_LOG_ERROR(FAILED)
                 << "Expect rank " << rank_id << " batch size " << agent_spec.batch_size
                 << " equal to 0 or rank 0's batch size " << batch_size_ << ", subgraph: " << subgraph;
        }
      }
    }
  }
  return SUCCESS;
}
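To make the stage arithmetic shared by CheckAgentsInfosAndInitTensorInfos and PredictInner concrete: with a hypothetical rank_size = 8 and stage_size = 2, parallel_count is 4, ranks 0..3 form stage 0 and receive the real inputs, ranks 4..7 receive empty requests (unless every stage declares inputs), and rank 7, the last rank, returns the result.

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t rank_size = 8, stage_size = 2;                 // hypothetical deployment
  const uint32_t agent_num_per_stage = rank_size / stage_size;  // parallel_count == 4
  const uint32_t result_agent_id = rank_size - 1;               // rank 7 returns the outputs
  for (uint32_t rank = 0; rank < rank_size; ++rank) {
    std::printf("rank %u: stage %u, %s%s\n", rank, rank / agent_num_per_stage,
                rank < agent_num_per_stage ? "real inputs" : "empty request",
                rank == result_agent_id ? ", returns result" : "");
  }
  return 0;
}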
<< " declared in servable config"; } auto parallel_count = rank_size / stage_size; constexpr size_t card_count_per_machine = 8; if (stage_size == 1) { std::map> device_map; for (size_t i = 0; i < rank_size; i++) { const auto &item = config_.rank_list[i]; auto &device_id_list = device_map[item.ip]; if (device_id_list.count(item.device_id) > 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id repeatedly used by rank " << i << " in device ip " << item.ip; } if (item.device_id >= card_count_per_machine) { return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id cannot larger than 8"; } (void)device_id_list.emplace(item.device_id); } } else { if (rank_size < card_count_per_machine) { return INFER_STATUS_LOG_ERROR(FAILED) << "Rank size " << rank_size << "must >= card count " << card_count_per_machine << " of one machine when stage size " << stage_size << " > 1"; } for (size_t i = 0; i < rank_size; i += card_count_per_machine) { const auto &first_item = config_.rank_list[i]; for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { auto rank_id = i + k; const auto &item = config_.rank_list[rank_id]; if (k != item.device_id) { return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; } if (first_item.ip != item.ip) { return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id << " to be equal with device ip " << first_item.ip << " of rank " << i; } } } } MSI_LOG_INFO << "Check rank table success, rank size: " << rank_size << ", stage size: " << stage_size << ", parallel count in one stage: " << parallel_count; return SUCCESS; } void DistributedModelLoader::OnAgentFailed() { MSI_LOG_INFO << "Worker agent notify failed"; SetWaitAgentsPromise(false); } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/worker/distributed_worker/distributed_model_loader.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H #define MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H #include #include #include #include #include #include #include #include "mindspore_serving/ccsrc/worker/model_loader_base.h" #include "worker/distributed_worker/common.h" #include "worker/distributed_worker/notify_agent/base_notify_agent.h" using nlohmann::json; namespace mindspore { namespace serving { struct DistributedAgentContext { std::vector agent_spec_; std::shared_ptr notify_agent_ = nullptr; }; class MS_API DistributedModelLoader final : public DirectModelLoaderBase { public: DistributedModelLoader() = default; ~DistributedModelLoader(); // from python, worker.py Status LoadModel(const std::string &servable_name, const std::string &rank_table_json_file, uint64_t wait_agents_time_in_seconds); // invoke from agent Status GetDistributedServableConfig(DistributedServableConfig *config) const; // send model and group // register and unregister agent, agent_spec_list_ Status RegisterAgent(const std::vector &agent_specs); Status OnAgentExit(); // predict, use config_ and agent_spec_list_ Status Predict(const std::vector &input, std::vector *output, uint64_t subgraph = 0) override; std::vector GetInputInfos(uint64_t subgraph = 0) const override; std::vector GetOutputInfos(uint64_t subgraph = 0) const override; uint64_t GetBatchSize() const override; void Clear() override; void OnAgentFailed(); uint64_t GetGraphNum() const override; std::string GetModelKey() const { return model_key_; } private: DistributedServableConfig config_; std::atomic_bool config_loaded_ = false; std::string model_key_; std::atomic_bool model_loaded_ = false; uint64_t graph_num_ = 0; std::shared_mutex rw_mutex_; std::mutex wait_mutex_; std::map agent_spec_map_; std::string rank_table_json_file_; std::map> input_infos_; std::map> output_infos_; uint64_t batch_size_; bool all_stage_has_input_ = false; std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT; std::atomic_bool registered_end_flag_ = false; std::promise agents_promise_; Status InitConfigOnStartup(const std::string &rank_table_json_file); Status WaitAgentsReady(uint64_t wait_agents_time_in_seconds); Status CheckAgentsInfosAndInitTensorInfos(); Status CompareTensorInfos(const std::vector &lefts, const std::vector &rights); Status CheckRankConfig(); void SetWaitAgentsPromise(bool flag); Status PredictInner(const std::vector &input, std::vector *output, uint64_t subgraph = 0); // agent stubs Status ParserRankTableWithGroupList(const std::string &rank_table_json_file, const json &rank_table_json); Status ParserRankTableWithServerList(const std::string &rank_table_json_file, const json &rank_table_json); json ParserArrayInJson(const json &json_array, const std::string &str); std::string ParserStringInJson(const json &json_str, const std::string &str); Status ConvertStr2Int(const std::string &rank_table_json_file, const std::string ¶_str, const std::string ¶_key, uint32_t *para_int) const; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H ================================================ FILE: mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
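A condensed sketch of the loader's life cycle as exposed by the header above; ServeOnce is hypothetical driver code (the servable name and rank table file echo the matmul_distributed example), whereas in the repository these steps are orchestrated from Python and the gRPC handlers.

#include <vector>
#include "worker/distributed_worker/distributed_model_loader.h"

mindspore::serving::Status ServeOnce(mindspore::serving::DistributedModelLoader *loader,
                                     const std::vector<mindspore::serving::TensorBasePtr> &inputs,
                                     std::vector<mindspore::serving::TensorBasePtr> *outputs) {
  // blocks until all agents have called RegisterAgent, or the 60 s budget runs out
  auto status = loader->LoadModel("matmul", "rank_table_8pcs.json", 60);
  if (status != mindspore::serving::SUCCESS) {
    return status;
  }
  // fans the request out to every rank; the outputs come from the last rank
  return loader->Predict(inputs, outputs, /*subgraph=*/0);
}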
================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "worker/distributed_worker/distributed_process/distributed_process.h"
#include <string>  // NOTE: restored; original system includes lost in extraction
#include <vector>
#include "worker/worker.h"
#include "common/proto_tensor.h"

namespace mindspore {
namespace serving {

grpc::Status MSDistributedImpl::AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request,
                                              proto::AgentRegisterReply *reply) {
  MSI_EXCEPTION_IF_NULL(request);
  MSI_EXCEPTION_IF_NULL(reply);
  std::vector<WorkerAgentSpec> agent_specs;
  for (auto &spec : request->agent_spec()) {
    WorkerAgentSpec agent_spec;
    agent_spec.agent_address = request->address();
    GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec);
    agent_specs.push_back(agent_spec);
  }
  if (agent_specs.size() == 0) {
    MSI_LOG(ERROR) << "Agent Register FAILED, agent_specs size is 0";
  }
  Status status = servable_->RegisterAgent(agent_specs);
  if (status != SUCCESS) {
    MSI_LOG(ERROR) << "Agent Register FAILED";
  }
  watcher_->StartWatch(request->address());
  return grpc::Status::OK;
}

grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request,
                                          proto::AgentExitReply *reply) {
  MSI_EXCEPTION_IF_NULL(request);
  MSI_EXCEPTION_IF_NULL(reply);
  if (request->address_choice_case() == proto::AgentExitRequest::kAddress) {
    watcher_->StopWatch(request->address());
  }
  MSI_LOG_INFO << "Agent exit, address: '" << request->address() << "', agent ip: '" << request->agent_ip() << "'";
  servable_->OnAgentExit();
  Worker::GetInstance().StopServable();
  return grpc::Status::OK;
}

grpc::Status MSDistributedImpl::AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
                                            proto::AgentFailedReply *reply) {
  if (Worker::GetInstance().IsRunning()) {
    MSI_LOG_ERROR << "Expect worker should not be running";
    Worker::GetInstance().StopServable();
  } else {
    servable_->OnAgentFailed();
  }
  return grpc::Status::OK;
}

grpc::Status MSDistributedImpl::AgentConfigAcquire(grpc::ServerContext *context,
                                                   const proto::AgentConfigAcquireRequest *request,
                                                   proto::AgentConfigAcquireReply *reply) {
  DistributedServableConfig agent_config;
  Status status = servable_->GetDistributedServableConfig(&agent_config);
  if (status != SUCCESS) {
    MSI_LOG(ERROR) << "Get distributed servable config failed";
    return grpc::Status::CANCELLED;
  }
  MSI_LOG(INFO) << "Begin to set DistributedServableConfig info in reply message";
  // set reply message:AgentConfigAcquireReply, parameter:rank_table_content
  reply->set_rank_table_content(agent_config.rank_table_content);
  // set reply message:AgentConfigAcquireReply, parameter:rank_list
  auto &agent_rank_list = agent_config.rank_list;
  for (auto &agent_rank : agent_rank_list) {
    auto rank_list = reply->add_rank_list();
    rank_list->set_ip(agent_rank.ip);
    rank_list->set_device_id(agent_rank.device_id);
  }
  // set reply message:AgentConfigAcquireReply, parameter:common_meta
  auto reply_common_meta = reply->mutable_common_meta();
  reply_common_meta->set_servable_name(agent_config.common_meta.servable_name);
  reply_common_meta->set_model_key(agent_config.common_meta.model_key);
  reply_common_meta->set_with_batch_dim(agent_config.common_meta.with_batch_dim);
  auto &without_batch_dim_inputs_list = agent_config.common_meta.without_batch_dim_inputs;
  for (auto &without_batch_dim_input : without_batch_dim_inputs_list) {
    reply_common_meta->add_without_batch_dim_inputs(without_batch_dim_input);
  }
  auto &proto_input_count = *(reply_common_meta->mutable_inputs_count());
  for (auto &inputs_count : agent_config.common_meta.inputs_count) {
    proto_input_count[inputs_count.first] = inputs_count.second;
  }
  auto &proto_output_count = *(reply_common_meta->mutable_outputs_count());
  for (auto &outputs_count : agent_config.common_meta.outputs_count) {
    proto_output_count[outputs_count.first] = outputs_count.second;
  }
  // set reply message:AgentConfigAcquireReply, parameter:distributed_meta
  auto reply_distributed_meta = reply->mutable_distributed_meta();
  reply_distributed_meta->set_rank_size(agent_config.distributed_meta.rank_size);
  reply_distributed_meta->set_stage_size(agent_config.distributed_meta.stage_size);
  MSI_LOG(INFO) << "Success to set DistributedServableConfig info in reply message";
  return grpc::Status::OK;
}

grpc::Status MSDistributedImpl::Ping(grpc::ServerContext *context, const proto::PingRequest *request,
                                     proto::PingReply *reply) {
  MSI_EXCEPTION_IF_NULL(request);
  MSI_EXCEPTION_IF_NULL(reply);
  watcher_->RecvPing(request->address());
  return grpc::Status::OK;
}

grpc::Status MSDistributedImpl::Pong(grpc::ServerContext *context, const proto::PongRequest *request,
                                     proto::PongReply *reply) {
  MSI_EXCEPTION_IF_NULL(request);
  MSI_EXCEPTION_IF_NULL(reply);
  watcher_->RecvPong(request->address());
  return grpc::Status::OK;
}
}  // namespace serving
}  // namespace mindspore
================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H
#define MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H

// NOTE: system include targets were lost in extraction and are restored from usage; the Watcher
// template arguments (if any, see common/heart_beat.h) could not be recovered.
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "common/serving_common.h"
#include "common/heart_beat.h"
#include "proto/ms_service.pb.h"
#include "proto/ms_service.grpc.pb.h"
#include "proto/ms_distributed.pb.h"
#include "proto/ms_distributed.grpc.pb.h"
#include "worker/distributed_worker/distributed_model_loader.h"
#include "worker/grpc/worker_process.h"

namespace mindspore {
namespace serving {
// Service Implement
class MSDistributedImpl {
 public:
  explicit MSDistributedImpl(std::shared_ptr<DistributedModelLoader> servable, const std::string server_address)
      : servable_(servable) {
    if (!watcher_) {
      watcher_ = std::make_shared<Watcher>(server_address);  // Watcher template arguments lost in extraction
    }
  }
  ~MSDistributedImpl() = default;

  grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request,
                             proto::AgentRegisterReply *reply);
  grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request,
                         proto::AgentExitReply *reply);
  grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
                           proto::AgentFailedReply *reply);
  grpc::Status AgentConfigAcquire(grpc::ServerContext *context, const proto::AgentConfigAcquireRequest *request,
                                  proto::AgentConfigAcquireReply *reply);
  grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply);
  grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply);

 private:
  std::shared_ptr<DistributedModelLoader> servable_;
  std::shared_ptr<Watcher> watcher_;  // Watcher template arguments lost in extraction
};
}  // namespace serving
}  // namespace mindspore

#endif  // MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H
================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H

// NOTE: system include targets and template arguments were lost in extraction; the template
// arguments below are restored on the assumption that GrpcAsyncServiceContext follows the CRTP
// pattern implied by the static EnqueueRequest calls at the bottom of this file.
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "common/serving_common.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"
#include "common/grpc_async_server.h"
#include "worker/grpc/worker_process.h"
#include "worker/grpc/worker_server.h"
#include "worker/distributed_worker/distributed_process/distributed_process.h"

namespace mindspore {
namespace serving {
template <class Derived>
class DistributedServiceContext
    : public GrpcAsyncServiceContext<MSDistributedImpl, proto::MSDistributedWorker::AsyncService, Derived> {
 public:
  DistributedServiceContext(MSDistributedImpl *service_impl, proto::MSDistributedWorker::AsyncService *async_service,
                            grpc::ServerCompletionQueue *cq)
      : GrpcAsyncServiceContext<MSDistributedImpl, proto::MSDistributedWorker::AsyncService, Derived>(
          service_impl, async_service, cq) {}
  virtual void StartEnqueueRequest() = 0;
  virtual void HandleRequest() = 0;
};

// Service Implement
class WorkerAgentRegisterContext : public DistributedServiceContext<WorkerAgentRegisterContext> {
 public:
  WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSDistributedWorker::AsyncService *async_service,
                             grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerAgentRegisterContext() = default;

  void StartEnqueueRequest() override {
    async_service_->RequestAgentRegister(&ctx_, &request_, &responder_, cq_, cq_, this);
  }
  void HandleRequest() override {
    grpc::Status status = service_impl_->AgentRegister(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::AgentRegisterReply> responder_;
  proto::AgentRegisterRequest request_;
  proto::AgentRegisterReply response_;
};

class WorkerAgentExitContext : public DistributedServiceContext<WorkerAgentExitContext> {
 public:
  WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSDistributedWorker::AsyncService *async_service,
                         grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerAgentExitContext() = default;

  void StartEnqueueRequest() override { async_service_->RequestAgentExit(&ctx_, &request_, &responder_, cq_, cq_, this); }
  void HandleRequest() override {
    grpc::Status status = service_impl_->AgentExit(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::AgentExitReply> responder_;
  proto::AgentExitRequest request_;
  proto::AgentExitReply response_;
};

class WorkerAgentFailedContext : public DistributedServiceContext<WorkerAgentFailedContext> {
 public:
  WorkerAgentFailedContext(MSDistributedImpl *service_impl, proto::MSDistributedWorker::AsyncService *async_service,
                           grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerAgentFailedContext() = default;

  void StartEnqueueRequest() override {
    async_service_->RequestAgentFailed(&ctx_, &request_, &responder_, cq_, cq_, this);
  }
  void HandleRequest() override {
    grpc::Status status = service_impl_->AgentFailed(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::AgentFailedReply> responder_;
  proto::AgentFailedRequest request_;
  proto::AgentFailedReply response_;
};

class WorkerAgentConfigAcquireContext : public DistributedServiceContext<WorkerAgentConfigAcquireContext> {
 public:
  WorkerAgentConfigAcquireContext(MSDistributedImpl *service_impl,
                                  proto::MSDistributedWorker::AsyncService *async_service,
                                  grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerAgentConfigAcquireContext() = default;

  void StartEnqueueRequest() override {
    async_service_->RequestAgentConfigAcquire(&ctx_, &request_, &responder_, cq_, cq_, this);
  }
  void HandleRequest() override {
    grpc::Status status = service_impl_->AgentConfigAcquire(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::AgentConfigAcquireReply> responder_;
  proto::AgentConfigAcquireRequest request_;
  proto::AgentConfigAcquireReply response_;
};

class WorkerPingContext : public DistributedServiceContext<WorkerPingContext> {
 public:
  WorkerPingContext(MSDistributedImpl *service_impl, proto::MSDistributedWorker::AsyncService *async_service,
                    grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerPingContext() = default;

  void StartEnqueueRequest() override { async_service_->RequestPing(&ctx_, &request_, &responder_, cq_, cq_, this); }
  void HandleRequest() override {
    grpc::Status status = service_impl_->Ping(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::PingReply> responder_;
  proto::PingRequest request_;
  proto::PingReply response_;
};

class WorkerPongContext : public DistributedServiceContext<WorkerPongContext> {
 public:
  WorkerPongContext(MSDistributedImpl *service_impl, proto::MSDistributedWorker::AsyncService *async_service,
                    grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerPongContext() = default;

  void StartEnqueueRequest() override { async_service_->RequestPong(&ctx_, &request_, &responder_, cq_, cq_, this); }
  void HandleRequest() override {
    grpc::Status status = service_impl_->Pong(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::PongReply> responder_;
  proto::PongRequest request_;
  proto::PongReply response_;
};

class DistributedWorkerGrpcServer : public GrpcAsyncServer<proto::MSDistributedWorker::AsyncService> {
 public:
  DistributedWorkerGrpcServer(std::shared_ptr<DistributedModelLoader> servable, const std::string server_address)
      : GrpcAsyncServer<proto::MSDistributedWorker::AsyncService>(),
        service_impl_(MSDistributedImpl(servable, server_address)) {}

  void EnqueueRequests() override {
    WorkerAgentRegisterContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    WorkerAgentExitContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    WorkerAgentFailedContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    WorkerAgentConfigAcquireContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    WorkerPingContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    WorkerPongContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
  }

 private:
  MSDistributedImpl service_impl_;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
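The context classes above are mechanical: each new RPC adds one class with the same three overrides plus an EnqueueRequest call in the server. To illustrate, a hypothetical "AgentStatus" RPC would look like the sketch below; note that proto::AgentStatusRequest/Reply, RequestAgentStatus and MSDistributedImpl::AgentStatus do not exist in the .proto files or this repository, they are invented here purely to show the pattern.

class WorkerAgentStatusContext : public DistributedServiceContext<WorkerAgentStatusContext> {
 public:
  WorkerAgentStatusContext(MSDistributedImpl *service_impl, proto::MSDistributedWorker::AsyncService *async_service,
                           grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}

  void StartEnqueueRequest() override {
    // hypothetical generated method for the hypothetical RPC
    async_service_->RequestAgentStatus(&ctx_, &request_, &responder_, cq_, cq_, this);
  }
  void HandleRequest() override {
    grpc::Status status = service_impl_->AgentStatus(&ctx_, &request_, &response_);  // hypothetical handler
    responder_.Finish(response_, status, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::AgentStatusReply> responder_;  // hypothetical message types
  proto::AgentStatusRequest request_;
  proto::AgentStatusReply response_;
};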
================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H
#define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H

#include <memory>  // NOTE: restored; original system includes lost in extraction
#include <string>
#include "common/serving_common.h"
#include "common/servable.h"
#include "proto/ms_agent.pb.h"
#include "common/grpc_client.h"

namespace mindspore {
namespace serving {
class MS_API BaseNotifyAgent {
 public:
  BaseNotifyAgent() = default;
  virtual ~BaseNotifyAgent() = default;
  virtual Status Exit() = 0;
  virtual Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
                               AsyncPredictCallback callback) = 0;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H

================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "worker/distributed_worker/notify_agent/notify_agent.h"
// NOTE: restored; original system includes lost in extraction
#include <chrono>
#include <memory>
#include <string>
#include "common/exit_handle.h"
#include "common/grpc_server.h"
#include "common/grpc_client.h"

namespace mindspore {
namespace serving {
GrpcNotifyAgent::GrpcNotifyAgent(const std::string &agent_address) {
  agent_address_ = agent_address;
  std::shared_ptr<grpc::Channel> channel = GrpcServer::CreateChannel(agent_address_);
  stub_ = proto::MSAgent::NewStub(channel);
}

GrpcNotifyAgent::~GrpcNotifyAgent() = default;

Status GrpcNotifyAgent::Exit() {
  if (stub_) {
    proto::DistributedExitRequest request;
    request.set_address(agent_address_);
    proto::DistributedExitReply reply;
    grpc::ClientContext context;
    const int32_t TIME_OUT = 1;
    std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT);
    context.set_deadline(deadline);
    auto status = stub_->Exit(&context, request, &reply);
    if (status.ok()) {
      MSI_LOG_INFO << "Notify one agent exit success, agent address: " << agent_address_;
    } else {
      MSI_LOG_INFO << "Notify one agent exit failed, agent address: " << agent_address_
                   << ", error: " << status.error_code() << ", " << status.error_message();
    }
  }
  return SUCCESS;
}

Status GrpcNotifyAgent::DispatchAsync(const proto::DistributedPredictRequest &request,
                                      proto::DistributedPredictReply *reply, AsyncPredictCallback callback) {
  if (!stub_) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Predict failed, agent gRPC has not been inited or has already exited, agent address "
           << agent_address_;
  }
  if (!distributed_client_) {
    // NOTE: the concrete async-client type was lost in extraction; MSDistributedClient is a
    // placeholder name for the client defined in common/grpc_client.h.
    distributed_client_ = std::make_unique<MSDistributedClient>();
    distributed_client_->Start();
  }
  distributed_client_->PredictAsync(request, reply, stub_.get(), callback, agent_address_);
  return SUCCESS;
}
}  // namespace serving
}  // namespace mindspore

================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H
#define MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H

#include <memory>  // NOTE: restored; original system includes lost in extraction
#include <string>
#include "worker/distributed_worker/notify_agent/base_notify_agent.h"
#include "proto/ms_agent.pb.h"
#include "proto/ms_agent.grpc.pb.h"

namespace mindspore {
namespace serving {
class MS_API GrpcNotifyAgent : public BaseNotifyAgent {
 public:
  explicit GrpcNotifyAgent(const std::string &agent_address);
  ~GrpcNotifyAgent() override;

  Status Exit() override;
  Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
                       AsyncPredictCallback callback) override;

 private:
  std::string agent_address_;
  std::shared_ptr<proto::MSAgent::Stub> stub_ = nullptr;
  // NOTE: this member is used in notify_agent.cc but its declaration was lost in extraction;
  // MSDistributedClient is a placeholder name for the async client type.
  std::unique_ptr<MSDistributedClient> distributed_client_ = nullptr;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H
*/ #include "worker/distributed_worker/notify_distributed/notify_worker.h" #include #include #include #include #include "common/exit_handle.h" #include "common/grpc_server.h" #include "common/proto_tensor.h" namespace mindspore { namespace serving { GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_address, const std::string &agent_address) : distributed_address_(distributed_address), agent_address_(agent_address) { auto channel = GrpcServer::CreateChannel(distributed_address_); stub_ = proto::MSDistributedWorker::NewStub(channel); } GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default; Status GrpcNotifyDistributeWorker::Register(const std::vector &worker_specs) { const int32_t REGISTER_INTERVAL = 1; MSI_LOG(INFO) << "Register to worker " << distributed_address_ << ", agent address: " << agent_address_; proto::AgentRegisterRequest request; GrpcTensorHelper::CopyFromWorkerAgentSpec(worker_specs, &request); request.set_address(agent_address_); proto::AgentRegisterReply reply; grpc::ClientContext context; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(REGISTER_INTERVAL); context.set_deadline(deadline); grpc::Status status = stub_->AgentRegister(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Register SUCCESS "; return SUCCESS; } return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register to worker failed, grpc error: " << status.error_code() << ", " << status.error_message(); } Status GrpcNotifyDistributeWorker::Unregister() { if (is_stoped_.load()) { return SUCCESS; } is_stoped_ = true; proto::AgentExitRequest request; request.set_address(agent_address_); proto::AgentExitReply reply; grpc::ClientContext context; const int32_t TIME_OUT = 1; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); context.set_deadline(deadline); grpc::Status status = stub_->AgentExit(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Exit SUCCESS "; return SUCCESS; } return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed"; } Status GrpcNotifyDistributeWorker::NotifyFailed(const std::string &distributed_address) { auto channel = GrpcServer::CreateChannel(distributed_address); auto stub = proto::MSDistributedWorker::NewStub(channel); grpc::ClientContext context; proto::AgentFailedRequest request; proto::AgentFailedReply reply; grpc::Status status = stub->AgentFailed(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Success to notify failure of agent"; return SUCCESS; } return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to notify failure of agent"; } void GrpcNotifyDistributeWorker::StartupNotifyExit(const std::string &distributed_address, const std::string &agent_ip) { auto channel = GrpcServer::CreateChannel(distributed_address); auto stub = proto::MSDistributedWorker::NewStub(channel); grpc::ClientContext context; proto::AgentExitRequest request; request.set_agent_ip(agent_ip); proto::AgentExitReply reply; grpc::Status status = stub->AgentExit(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Success to notify exit of agent start up process"; } else { MSI_LOG(INFO) << "Failed to notify exit of agent start up process"; } } Status GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker(const std::string &distributed_address, DistributedServableConfig *config) { const int32_t REGISTER_TIME_OUT = 60; const int32_t REGISTER_INTERVAL = 1; auto loop = REGISTER_TIME_OUT; while (loop-- && 
!ExitSignalHandle::Instance().HasStopped()) { auto channel = GrpcServer::CreateChannel(distributed_address); auto stub = proto::MSDistributedWorker::NewStub(channel); grpc::ClientContext context; proto::AgentConfigAcquireRequest request; proto::AgentConfigAcquireReply reply; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(REGISTER_INTERVAL); context.set_deadline(deadline); grpc::Status status = stub->AgentConfigAcquire(&context, request, &reply); if (status.ok()) { return ParseAgentConfigAcquireReply(reply, config); } MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message(); std::this_thread::sleep_for(std::chrono::seconds(REGISTER_INTERVAL)); } if (ExitSignalHandle::Instance().HasStopped()) { return INFER_STATUS_LOG_WARNING(FAILED) << "Agent exit, stop get Agents configs from Worker"; } return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to get Agents configs from Worker, worker is not available."; } Status GrpcNotifyDistributeWorker::ParseAgentConfigAcquireReply(const proto::AgentConfigAcquireReply &reply, DistributedServableConfig *config) { MSI_LOG(INFO) << "Success to get Agents configs from Worker, and begin to parser"; // parser reply message:AgentConfigAcquireReply, parameter:rank_table_content config->rank_table_content = reply.rank_table_content(); // parser reply message:AgentConfigAcquireReply, parameter:rank_list for (auto &temp_rank : reply.rank_list()) { OneRankConfig ome_rank_config; ome_rank_config.ip = temp_rank.ip(); ome_rank_config.device_id = temp_rank.device_id(); config->rank_list.push_back(ome_rank_config); } // parser reply message:AgentConfigAcquireReply, parameter:common_meta auto &temp_common_meta = reply.common_meta(); config->common_meta.servable_name = temp_common_meta.servable_name(); config->common_meta.model_key = temp_common_meta.model_key(); config->common_meta.with_batch_dim = temp_common_meta.with_batch_dim(); for (auto &temp_without_batch_dim_inputs : temp_common_meta.without_batch_dim_inputs()) { config->common_meta.without_batch_dim_inputs.push_back(temp_without_batch_dim_inputs); } for (auto &count : temp_common_meta.inputs_count()) { config->common_meta.inputs_count[count.first] = count.second; } for (auto &count : temp_common_meta.outputs_count()) { config->common_meta.outputs_count[count.first] = count.second; } // parser reply message:AgentConfigAcquireReply, parameter:distributed_meta auto &temp_distributed_meta = reply.distributed_meta(); config->distributed_meta.rank_size = temp_distributed_meta.rank_size(); config->distributed_meta.stage_size = temp_distributed_meta.stage_size(); MSI_LOG(INFO) << "Success to parser reply message and save to DistributedServableConfig"; return SUCCESS; } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H #define MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H #include #include #include #include "common/serving_common.h" #include "worker/distributed_worker/common.h" #include "proto/ms_distributed.pb.h" #include "proto/ms_distributed.grpc.pb.h" #include "proto/ms_worker.pb.h" #include "proto/ms_worker.grpc.pb.h" namespace mindspore { namespace serving { class MS_API GrpcNotifyDistributeWorker { public: GrpcNotifyDistributeWorker(const std::string &distributed_address, const std::string &agent_address); ~GrpcNotifyDistributeWorker(); Status Register(const std::vector &agent_specs); Status Unregister(); // from start up, not agent static Status NotifyFailed(const std::string &distributed_address); static Status GetAgentsConfigsFromWorker(const std::string &distributed_address, DistributedServableConfig *config); static void StartupNotifyExit(const std::string &distributed_address, const std::string &agent_ip); private: static Status ParseAgentConfigAcquireReply(const proto::AgentConfigAcquireReply &reply, DistributedServableConfig *config); std::string distributed_address_; std::string agent_address_; std::unique_ptr stub_; std::atomic is_stoped_{false}; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H ================================================ FILE: mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
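A hypothetical register/unregister lifecycle for the client class declared above (RegisterThenLeave and both addresses are made up), mirroring what WorkerAgent::RegisterAgent and WorkerAgent::Clear do in worker_agent.cc below.

#include <vector>
#include "worker/distributed_worker/notify_distributed/notify_worker.h"

mindspore::serving::Status RegisterThenLeave(const std::vector<mindspore::serving::WorkerAgentSpec> &specs) {
  mindspore::serving::GrpcNotifyDistributeWorker notify("127.0.0.1:6200", "127.0.0.1:7000");
  auto status = notify.Register(specs);
  if (status != mindspore::serving::SUCCESS) {
    return status;
  }
  return notify.Unregister();  // idempotent: guarded by is_stopped_
}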
 */
#include "worker/distributed_worker/worker_agent.h"
#include <memory>
#include <string>
#include "worker/distributed_worker/agent_process/agent_process.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"
#include "common/exit_handle.h"
#include "common/proto_tensor.h"

namespace mindspore {
namespace serving {
WorkerAgent &WorkerAgent::Instance() {
  static WorkerAgent instance;
  return instance;
}

Status WorkerAgent::Clear() {
  if (notify_worker_) {
    if (exit_notify_worker_) {
      notify_worker_->Unregister();
      MSI_LOG_INFO << "Finished unregistering from the worker";
    }
    notify_worker_ = nullptr;
  }
  grpc_server_.Stop();
  if (session_ != nullptr) {
    session_->UnloadModel();
    session_ = nullptr;
  }
  return SUCCESS;
}

Status WorkerAgent::StartAgent(const AgentStartUpConfig &config, const std::string &dec_key,
                               const std::string &dec_mode) {
  session_ = InferenceLoader::Instance().CreateMindSporeInfer();
  if (session_ == nullptr) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Create MindSpore infer failed";
  }
  Status status;
  config_ = config;
  const auto &common_meta = config.common_meta;
  auto enable_lite = InferenceLoader::Instance().GetEnableLite();
  status = session_->LoadModelFromFile(kDeviceTypeAscend, config.device_id, config.model_file_names, kMindIR,
                                       common_meta.with_batch_dim, common_meta.without_batch_dim_inputs,
                                       ModelContext(), dec_key, dec_mode, {}, enable_lite);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << common_meta.servable_name
                  << ", rank_id: " << config.rank_id << ", device id: " << config.device_id
                  << ", model file: " << config.model_file_names
                  << ", rank table file: " << config.rank_table_json_file_name
                  << ", group config file: " << config.group_file_names;
    return status;
  }
  status = StartGrpcServer();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Start agent grpc server failed, agent address: " << config.agent_address;
    return status;
  }
  status = RegisterAgent();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Register agent failed, agent address: " << config.agent_address
                  << ", distributed worker address: " << config.distributed_address;
    return status;
  }
  MSI_LOG_INFO << "Started agent successfully, servable name: " << common_meta.servable_name
               << ", rank_id: " << config.rank_id << ", device id: " << config.device_id
               << ", model file: " << config.model_file_names
               << ", rank table file: " << config.rank_table_json_file_name
               << ", group config file: " << config.group_file_names;
  return SUCCESS;
}

Status WorkerAgent::StartGrpcServer() {
  std::string server_address = config_.agent_address;
  return grpc_server_.Start(std::make_shared(server_address), server_address, gRpcMaxMBMsgSize, "Agent");
}

Status WorkerAgent::RegisterAgent() {
  notify_worker_ = std::make_shared<GrpcNotifyDistributeWorker>(config_.distributed_address, config_.agent_address);
  auto graph_num = session_->GetSubGraphNum();
  if (graph_num == 0) {
    return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "RegisterAgent failed, Agent graph_num error";
  }
  std::vector<WorkerAgentSpec> worker_specs;
  for (uint64_t i = 0; i < graph_num; i++) {
    WorkerAgentSpec spec;
    spec.subgraph = i;
    spec.agent_address = config_.agent_address;
    spec.rank_id = config_.rank_id;
    spec.batch_size = session_->GetBatchSize(i);
    spec.input_infos = session_->GetInputInfos(i);
    spec.output_infos = session_->GetOutputInfos(i);
    worker_specs.push_back(spec);
  }
  return notify_worker_->Register(worker_specs);
}

void WorkerAgent::StopAgent(bool notify_worker) {
  exit_notify_worker_ = notify_worker;
  ExitSignalHandle::Instance().Stop();
}

class ProtoDistributedPredictRequest : public RequestBase {
 public:
  explicit ProtoDistributedPredictRequest(const proto::DistributedPredictRequest &other) : proto_request_(other) {
    for (int i = 0; i < proto_request_.inputs_size(); i++) {
      (void)tensor_list_.emplace_back(const_cast<proto::Tensor *>(&proto_request_.inputs(i)));
    }
  }
  ~ProtoDistributedPredictRequest() = default;

  size_t size() const override { return tensor_list_.size(); }
  const TensorBase *operator[](size_t index) const override {
    if (index >= tensor_list_.size()) {
      MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_.size();
    }
    return &tensor_list_[index];
  }

 private:
  std::vector<ProtoTensor> tensor_list_;
  const proto::DistributedPredictRequest &proto_request_;
};

class ProtoDistributedPredictReply : public ReplyBase {
 public:
  explicit ProtoDistributedPredictReply(proto::DistributedPredictReply *other) : proto_reply_(other) {}
  ~ProtoDistributedPredictReply() = default;

  size_t size() const override { return tensor_list_.size(); }
  TensorBase *operator[](size_t index) override {
    if (index >= tensor_list_.size()) {
      MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_.size();
    }
    return &tensor_list_[index];
  }
  const TensorBase *operator[](size_t index) const override {
    if (index >= tensor_list_.size()) {
      MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_.size();
    }
    return &tensor_list_[index];
  }
  TensorBase *add() override {
    auto tensor = proto_reply_->add_outputs();
    ProtoTensor proto_tensor(tensor);
    tensor_list_.push_back(proto_tensor);
    return &(tensor_list_.back());
  }
  void clear() override { tensor_list_.clear(); }

 private:
  proto::DistributedPredictReply *proto_reply_;
  std::vector<ProtoTensor> tensor_list_;
};

Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) {
  if (session_ == nullptr) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Model is not loaded";
  }
  Status status;
  try {
    MSI_TIME_STAMP_START(ExecuteModel)
    ProtoDistributedPredictRequest request_wrap(request);
    ProtoDistributedPredictReply reply_wrap(reply);
    status = session_->ExecuteModel(request_wrap, &reply_wrap, request.return_result(), request.subgraph());
    MSI_TIME_STAMP_END(ExecuteModel)
  } catch (const std::bad_alloc &ex) {
    status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: malloc memory failed";
  } catch (const std::runtime_error &ex) {
    status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: runtime error occurred: " << ex.what();
  } catch (const std::exception &ex) {
    status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: exception occurred: " << ex.what();
  } catch (...) {
    status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: exception occurred";
  }
  if (status != SUCCESS) {
    reply->Clear();
    auto error_msg = reply->mutable_error_msg();
    error_msg->set_error_code(status.StatusCode());
    error_msg->set_error_msg(status.StatusMessage());
  }
  return status;
}
}  // namespace serving
}  // namespace mindspore


================================================
FILE: mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_AGENT_H
#define MINDSPORE_SERVING_WORKER_AGENT_H

#include <memory>
#include <string>
#include <vector>
#include "proto/ms_agent.pb.h"
#include "proto/ms_agent.grpc.pb.h"
#include "common/grpc_server.h"
#include "worker/distributed_worker/common.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"
#include "worker/inference/inference.h"

namespace mindspore {
namespace serving {
class MS_API WorkerAgent {
 public:
  static WorkerAgent &Instance();
  Status Clear();
  Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply);
  Status StartAgent(const AgentStartUpConfig &config, const std::string &dec_key, const std::string &dec_mode);
  void StopAgent(bool notify_worker = true);

 private:
  AgentStartUpConfig config_;
  std::shared_ptr<InferenceBase> session_ = nullptr;
  GrpcServer grpc_server_;
  bool exit_notify_worker_ = true;
  std::shared_ptr<GrpcNotifyDistributeWorker> notify_worker_;

  Status StartGrpcServer();
  Status RegisterAgent();
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_WORKER_AGENT_H


================================================
FILE: mindspore_serving/ccsrc/worker/extra_worker/remote_call_model.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
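
GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker in notify_worker.cc above polls the distributed worker in a loop: each attempt creates a fresh channel and stub, sets a REGISTER_INTERVAL-second deadline on the call, returns as soon as one call succeeds, and sleeps one interval between failures until shutdown is requested. A minimal, gRPC-free sketch of that control flow; attempt_rpc, kIntervalSeconds and stop_requested are illustrative stand-ins, not Serving APIs:

// Standalone sketch of the poll-until-success-or-shutdown pattern used by
// GetAgentsConfigsFromWorker. attempt_rpc stands in for the gRPC call.
#include <atomic>
#include <chrono>
#include <functional>
#include <iostream>
#include <thread>

namespace {
constexpr int kIntervalSeconds = 1;       // stand-in for REGISTER_INTERVAL
std::atomic<bool> stop_requested{false};  // stand-in for ExitSignalHandle

// Returns true on success. A real implementation would issue the RPC with a
// deadline of now() + kIntervalSeconds, exactly as the worker code does.
bool PollUntilReadyOrStopped(const std::function<bool()> &attempt_rpc) {
  while (!stop_requested.load()) {
    if (attempt_rpc()) {
      return true;  // equivalent to the status.ok() early return
    }
    // Failed attempt: wait one interval before re-creating the channel/stub.
    std::this_thread::sleep_for(std::chrono::seconds(kIntervalSeconds));
  }
  return false;  // shutdown requested while still unconfigured
}
}  // namespace

int main() {
  int calls = 0;
  // Simulated RPC that succeeds on the third attempt.
  bool ok = PollUntilReadyOrStopped([&calls]() { return ++calls >= 3; });
  std::cout << "succeeded: " << ok << " after " << calls << " attempts\n";
  return 0;
}
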
*/ #include "worker/extra_worker/remote_call_model.h" #include #include #include "worker/notfiy_master/grpc_notify.h" #include "common/proto_tensor.h" #include "worker/worker.h" namespace mindspore::serving { Status RemoteCallModel::InitRemote(const std::string &servable_name, uint32_t version_number, const std::string &master_address, std::map> *models) { MSI_EXCEPTION_IF_NULL(models); proto::GetModelInfoReply reply; auto status = GrpcNotifyMaster::GetModelInfos(master_address, servable_name, version_number, &reply); if (status != SUCCESS) { return status; } if (reply.error_msg().error_code() != 0) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << reply.error_msg().error_msg(); } std::map model_infos; GrpcTensorHelper::ConvertProtoModelInfos(reply.model_infos(), &model_infos); for (auto &model_it : model_infos) { auto &model_name = model_it.first; auto &model_info = model_it.second; auto model_loader = std::make_shared(); (void)models->emplace(model_name, model_loader); status = model_loader->InitModel(model_name, version_number, model_info); if (status != SUCCESS) { for (auto &item : *models) { item.second->Clear(); } return status; } } return SUCCESS; } Status RemoteCallModel::InitModel(const std::string &model_key, uint32_t version_number, const ModelInfo &model_info) { model_key_ = model_key; batch_size_ = model_info.batch_size; if (batch_size_ == 0) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Batch size cannot be 0"; } auto &subgraph_infos = model_info.sub_graph_infos; subgraph_contexts_.resize(subgraph_infos.size()); for (size_t i = 0; i < subgraph_infos.size(); i++) { auto &subgraph_info = subgraph_infos[i]; RemoteCallModelContext &context = subgraph_contexts_[i]; context.model_name = model_key; context.version_number = version_number; context.subgraph = i; context.input_infos = subgraph_info.input_infos; for (auto &tensor_info : subgraph_info.output_infos) { TensorInfoOutput output_info; output_info.tensor_info = tensor_info; context.output_infos.push_back(output_info); } } auto status = InitModelExecuteInfo(); if (status != SUCCESS) { return status; } return SUCCESS; } std::vector RemoteCallModel::GetInputInfos(uint64_t subgraph) const { if (subgraph >= subgraph_contexts_.size()) { MSI_LOG_EXCEPTION << "Cannot find subgraph " << subgraph << " in model " << model_key_; } return subgraph_contexts_[subgraph].input_infos; } std::vector RemoteCallModel::GetOutputInfos(uint64_t subgraph) const { if (subgraph >= subgraph_contexts_.size()) { MSI_LOG_EXCEPTION << "Cannot find subgraph " << subgraph << " in model " << model_key_; } std::vector output_tensors; for (auto &item : subgraph_contexts_[subgraph].output_infos) { // cppcheck-suppress useStlAlgorithm output_tensors.push_back(item.tensor_info); } return output_tensors; } uint64_t RemoteCallModel::GetBatchSize() const { return batch_size_; } uint64_t RemoteCallModel::GetGraphNum() const { return subgraph_contexts_.size(); } void RemoteCallModel::Clear() { subgraph_contexts_.clear(); } Status RemoteCallModel::Predict(const std::vector &inputs, std::vector *outputs, uint64_t subgraph) { auto notify_master = Worker::GetInstance().GetGrpcNotifyMaster(); if (notify_master == nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Get notify master failed"; } if (subgraph >= subgraph_contexts_.size()) { MSI_LOG_EXCEPTION << "Cannot find subgraph " << subgraph << " in model " << model_key_; } return notify_master->CallModel(subgraph_contexts_[subgraph], inputs, outputs); } Status RemoteCallModel::InitModelExecuteInfo() { auto pid = 
getpid(); Status status; constexpr uint32_t cache_times = 3; auto &shared_memory = SharedMemoryAllocator::Instance(); for (auto &subgraph : subgraph_contexts_) { for (size_t i = 0; i < subgraph.input_infos.size(); i++) { auto &tensor_info = subgraph.input_infos[i]; uint64_t size_one_batch = tensor_info.size; if (!tensor_info.is_no_batch_dim) { size_one_batch = size_one_batch / batch_size_; } auto memory_key = model_key_ + "_subgraph" + std::to_string(subgraph.subgraph) + "_input" + std::to_string(i) + "_pid" + std::to_string(pid); uint64_t init_count = batch_size_ * cache_times; status = shared_memory.NewMemoryBuffer(memory_key, size_one_batch, init_count); if (status != SUCCESS) { return INFER_STATUS_LOG_ERROR(FAILED) << "Init input shared memory failed, item size: " << size_one_batch << ", initial count: " << init_count; } subgraph.request_memory.push_back(memory_key); } for (size_t i = 0; i < subgraph.output_infos.size(); i++) { auto &output_info = subgraph.output_infos[i]; auto &tensor_info = output_info.tensor_info; if (tensor_info.is_no_batch_dim) { output_info.shape_one_batch = tensor_info.shape; output_info.size_one_batch = tensor_info.size; } else { output_info.shape_one_batch = tensor_info.shape; (void)output_info.shape_one_batch.erase(output_info.shape_one_batch.begin()); // the batch size has been checked in WorkerExecutor output_info.size_one_batch = tensor_info.size / batch_size_; } auto memory_key = model_key_ + "_subgraph" + std::to_string(subgraph.subgraph) + "_output" + std::to_string(i) + "_pid" + std::to_string(pid); uint64_t init_count = batch_size_ * cache_times; status = shared_memory.NewMemoryBuffer(memory_key, output_info.size_one_batch, init_count); if (status != SUCCESS) { return INFER_STATUS_LOG_ERROR(FAILED) << "Init output shared memory failed, item size: " << output_info.size_one_batch << ", initial count: " << init_count; } subgraph.reply_memory.push_back(memory_key); } } return SUCCESS; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/worker/extra_worker/remote_call_model.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
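
RemoteCallModel::InitModelExecuteInfo above sizes one shared-memory item as the whole-tensor size divided by the batch size (unless the tensor has no batch dimension), allocates batch_size * cache_times items per pool, and builds each pool key from the model key, subgraph index, tensor index and pid. A self-contained sketch of that arithmetic, with hypothetical values:

// Sketch of the buffer-key and per-batch sizing scheme used by
// RemoteCallModel::InitModelExecuteInfo. All example values are made up.
#include <cstdint>
#include <iostream>
#include <string>
#include <unistd.h>

int main() {
  const std::string model_key = "resnet50.mindir";     // hypothetical model key
  const uint64_t batch_size = 8;
  const uint64_t tensor_size = 8 * 3 * 224 * 224 * 4;  // whole-batch bytes
  const bool is_no_batch_dim = false;
  constexpr uint32_t cache_times = 3;  // same cache factor as the source

  // One instance's share of the tensor, unless the batch dim is absent.
  const uint64_t size_one_batch = is_no_batch_dim ? tensor_size : tensor_size / batch_size;
  // A pool of batch_size * cache_times buffers, as in the source.
  const uint64_t init_count = batch_size * cache_times;

  const std::string memory_key = model_key + "_subgraph" + std::to_string(0) + "_input" + std::to_string(0) +
                                 "_pid" + std::to_string(getpid());
  std::cout << memory_key << ": " << init_count << " buffers of " << size_one_batch << " bytes\n";
  return 0;
}
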
*/ #ifndef MINDSPORE_SERVING_REMOTE_CALL_MODEL_H #define MINDSPORE_SERVING_REMOTE_CALL_MODEL_H #include #include #include #include #include "worker/model_loader_base.h" namespace mindspore::serving { struct RemoteCallModelContext { uint32_t version_number; std::string model_name; uint64_t subgraph; std::vector request_memory; std::vector reply_memory; std::vector input_infos; std::vector output_infos; }; class MS_API RemoteCallModel : public ModelLoaderBase { public: static Status InitRemote(const std::string &servable_name, uint32_t version_number, const std::string &master_address, std::map> *models); std::vector GetInputInfos(uint64_t subgraph = 0) const override; std::vector GetOutputInfos(uint64_t subgraph = 0) const override; uint64_t GetBatchSize() const override; uint64_t GetGraphNum() const override; void Clear() override; Status Predict(const std::vector &inputs, std::vector *outputs, uint64_t subgraph = 0) override; Status AfterLoadModel() override { return SUCCESS; } bool OwnDevice() const override { return false; } private: std::string model_key_; uint64_t batch_size_ = 0; std::vector subgraph_contexts_; Status InitModelExecuteInfo(); Status InitModel(const std::string &model_key, uint32_t version_number, const ModelInfo &model_info); }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_REMOTE_CALL_MODEL_H ================================================ FILE: mindspore_serving/ccsrc/worker/grpc/worker_process.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "worker/grpc/worker_process.h" #include "worker/worker.h" #include "common/proto_tensor.h" namespace mindspore { namespace serving { void MSWorkerImpl::Exit(const proto::ExitRequest *request, proto::ExitReply *reply) { MSI_LOG(INFO) << "Master Exit"; Worker::GetInstance().StopServable(false); } void MSWorkerImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, const PredictOnFinish &on_finish) { Status status(WORKER_UNAVAILABLE); try { status = Worker::GetInstance().RunAsync(*request, reply, on_finish); } catch (const std::bad_alloc &ex) { MSI_LOG(ERROR) << "Serving Error: malloc memory failed"; } catch (const std::runtime_error &ex) { MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what(); } catch (const std::exception &ex) { MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what(); } catch (...) { MSI_LOG(ERROR) << "Serving Error: exception occurred"; } if (status != SUCCESS) { GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); on_finish(); } } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/worker/grpc/worker_process.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
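
MSWorkerImpl::PredictAsync in worker_process.cc above funnels every exception into a non-SUCCESS status and, on failure, rewrites the reply as an error message and invokes the on_finish completion callback immediately, so the asynchronous caller is always completed exactly once. A standard-C++ sketch of that contract; Reply, OnFinish and run_async are stand-ins, not the Serving types:

// Sketch of the "catch everything, then always complete the callback" contract
// used by MSWorkerImpl::PredictAsync. All types here are illustrative stand-ins.
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>

struct Reply {
  std::string error;  // stands in for the proto error_msg field
};
using OnFinish = std::function<void()>;

void PredictAsyncSketch(Reply *reply, const OnFinish &on_finish, const std::function<void()> &run_async) {
  std::string error;
  try {
    run_async();  // on success, the engine invokes on_finish later by itself
    return;
  } catch (const std::bad_alloc &) {
    error = "malloc memory failed";
  } catch (const std::exception &ex) {
    error = ex.what();
  } catch (...) {
    error = "unknown exception";
  }
  // Failure path: record the error and complete the request right away,
  // mirroring CreateReplyFromErrorMsg + on_finish() in the source.
  reply->error = error;
  on_finish();
}

int main() {
  Reply reply;
  PredictAsyncSketch(&reply, [] { std::cout << "completed\n"; },
                     [] { throw std::runtime_error("model not loaded"); });
  std::cout << "error: " << reply.error << "\n";
  return 0;
}
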
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_WORKER_WORKER_PROCESS_H #define MINDSPORE_SERVING_WORKER_WORKER_PROCESS_H #include #include #include #include #include #include "common/serving_common.h" #include "common/heart_beat.h" #include "common/grpc_client.h" #include "proto/ms_worker.pb.h" #include "proto/ms_worker.grpc.pb.h" #include "proto/ms_master.pb.h" #include "proto/ms_master.grpc.pb.h" #include "proto/ms_agent.pb.h" #include "proto/ms_agent.grpc.pb.h" namespace mindspore { namespace serving { // Service Implement class MSWorkerImpl { public: MSWorkerImpl() = default; ~MSWorkerImpl() = default; void Exit(const proto::ExitRequest *request, proto::ExitReply *reply); void PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, const PredictOnFinish &on_finish); }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_WORKER_WORKER_PROCESS_H ================================================ FILE: mindspore_serving/ccsrc/worker/grpc/worker_server.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "worker/grpc/worker_server.h" #include #include #include "common/grpc_server.h" namespace mindspore { namespace serving {} // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/worker/grpc/worker_server.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
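
The worker_server.h header that follows accepts RPCs through per-request context objects: each context arms itself on the completion queue, serves exactly one request, and a fresh context of the same kind is armed so the next request can be accepted. A gRPC-free sketch of that life cycle, assuming (as the header does) that the base class exposes a static EnqueueRequest that constructs the derived context:

// gRPC-free sketch of the one-context-per-request pattern behind
// WorkerPredictContext/WorkerExitContext. The queue is a stand-in for
// grpc::ServerCompletionQueue; EnqueueRequest mirrors the static helper
// GrpcAsyncServiceContext is assumed to provide.
#include <iostream>
#include <memory>
#include <queue>

struct Context {
  virtual ~Context() = default;
  virtual void HandleRequest() = 0;
};

std::queue<std::unique_ptr<Context>> completion_queue;  // stand-in queue

template <class Derived>
struct ServiceContext : Context {
  static void EnqueueRequest() { completion_queue.push(std::make_unique<Derived>()); }
};

struct PredictContext : ServiceContext<PredictContext> {
  void HandleRequest() override {
    std::cout << "handled one Predict request\n";
    EnqueueRequest();  // re-arm: a new context waits for the next request
  }
};

int main() {
  PredictContext::EnqueueRequest();
  // Drain two events to show a fresh context replaced the consumed one.
  for (int i = 0; i < 2 && !completion_queue.empty(); ++i) {
    auto ctx = std::move(completion_queue.front());
    completion_queue.pop();
    ctx->HandleRequest();
  }
  return 0;
}
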
 */
#ifndef MINDSPORE_SERVING_WORKER_WORKER_SERVER_H
#define MINDSPORE_SERVING_WORKER_WORKER_SERVER_H

#include <memory>
#include <string>
#include "common/serving_common.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"
#include "common/grpc_async_server.h"
#include "worker/grpc/worker_process.h"
#include "worker/distributed_worker/distributed_model_loader.h"

namespace mindspore {
namespace serving {
template <class Derived>
class WorkerServiceContext : public GrpcAsyncServiceContext<MSWorkerImpl, proto::MSWorker::AsyncService, Derived> {
 public:
  WorkerServiceContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                       grpc::ServerCompletionQueue *cq)
      : GrpcAsyncServiceContext<MSWorkerImpl, proto::MSWorker::AsyncService, Derived>(service_impl, async_service,
                                                                                      cq) {}
  virtual void StartEnqueueRequest() = 0;
  virtual void HandleRequest() = 0;
};

class WorkerPredictContext : public WorkerServiceContext<WorkerPredictContext> {
 public:
  WorkerPredictContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                       grpc::ServerCompletionQueue *cq)
      : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerPredictContext() = default;

  void StartEnqueueRequest() override { async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this); }

  void HandleRequest() override {
    MSI_TIME_STAMP_START(WorkerRequestHandle)
    auto method_name = request_.servable_spec().method_name();
    PredictOnFinish on_finish = [this, method_name, time_start_WorkerRequestHandle]() {
      responder_.Finish(response_, grpc::Status::OK, this);
      MSI_TIME_STAMP_END_EXTRA(WorkerRequestHandle, "Method " + method_name)
    };
    service_impl_->PredictAsync(&request_, &response_, on_finish);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_;
  proto::PredictRequest request_;
  proto::PredictReply response_;
};

class WorkerExitContext : public WorkerServiceContext<WorkerExitContext> {
 public:
  WorkerExitContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                    grpc::ServerCompletionQueue *cq)
      : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerExitContext() = default;

  void StartEnqueueRequest() override { async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); }

  void HandleRequest() override {
    service_impl_->Exit(&request_, &response_);
    responder_.Finish(response_, grpc::Status::OK, this);
  }

 private:
  grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_;
  proto::ExitRequest request_;
  proto::ExitReply response_;
};

class WorkerGrpcServer : public GrpcAsyncServer<proto::MSWorker::AsyncService> {
 public:
  WorkerGrpcServer() : GrpcAsyncServer() {}

  void EnqueueRequests() override {
    WorkerPredictContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
    WorkerExitContext::EnqueueRequest(&service_impl_, &svc_, cq_.get());
  }

 protected:
  MSWorkerImpl service_impl_;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_WORKER_WORKER_SERVER_H


================================================
FILE: mindspore_serving/ccsrc/worker/inference/inference.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
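
inference.cc, which follows, binds the inference backend at run time: it dlopens the backend C++ library, then dlopens libserving_ascend.so and resolves the exported ServingCreateInfer factory via dlsym. A minimal sketch of that dlopen/dlsym factory pattern; libplugin.so and CreateInstance are hypothetical names, not the real libraries:

// Minimal sketch of the dlopen/dlsym plugin-factory pattern used by
// InferenceLoader::LoadMindSporeModelWrap. "libplugin.so" and "CreateInstance"
// are hypothetical; the real code resolves ServingCreateInfer from
// libserving_ascend.so. Link with -ldl.
#include <dlfcn.h>
#include <iostream>

struct PluginBase {  // stand-in for InferenceBase
  virtual ~PluginBase() = default;
};
using CreateHandle = PluginBase *(*)();

int main() {
  void *lib = dlopen("libplugin.so", RTLD_NOW | RTLD_GLOBAL);
  if (lib == nullptr) {
    // dlerror() explains the failure, e.g. a missing file or unresolved symbol.
    std::cerr << "dlopen failed: " << dlerror() << "\n";
    return 1;
  }
  auto create = reinterpret_cast<CreateHandle>(dlsym(lib, "CreateInstance"));
  if (create == nullptr) {
    std::cerr << "dlsym failed: " << dlerror() << "\n";
    (void)dlclose(lib);
    return 1;
  }
  PluginBase *instance = create();  // factory call, as with ServingCreateInfer
  delete instance;
  (void)dlclose(lib);
  return 0;
}
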
 */
#include "worker/inference/inference.h"
#include <dlfcn.h>
#include "glog/logging.h"
#include "worker/context.h"

namespace mindspore::serving {
namespace {
constexpr const char *kMindSporeLibName = "libmindspore.so";
constexpr const char *kMindsporeLiteLibName = "libmindspore-lite.so";
constexpr const char *kServingAscendLibName = "libserving_ascend.so";
}  // namespace

void ModelContext::AppendDeviceInfo(const DeviceInfo &device_info) { (void)device_list.emplace_back(device_info); }

std::string ModelContext::AsString() const {
  std::map<std::string, std::string> output_map;
  if (thread_num > -1) {
    output_map["thread num"] = AsStringHelper::AsString(thread_num);
  }
  if (!thread_affinity_core_list.empty()) {
    output_map["thread affinity core list"] = AsStringHelper::AsString(thread_affinity_core_list);
  }
  if (enable_parallel > -1) {
    output_map["enable parallel"] = AsStringHelper::AsString(enable_parallel);
  }
  if (!device_list.empty()) {
    output_map["device infos"] = AsStringHelper::AsString(device_list);
  }
  return AsStringHelper::AsString(output_map);
}

InferenceLoader::InferenceLoader() {}

InferenceLoader::~InferenceLoader() {
  if (ms_lib_handle_ != nullptr) {
    (void)dlclose(ms_lib_handle_);
    ms_lib_handle_ = nullptr;
  }
  if (ms_cxx_lib_handle_ != nullptr) {
    (void)dlclose(ms_cxx_lib_handle_);
    ms_cxx_lib_handle_ = nullptr;
  }
  if (gomp_handler_ != nullptr) {
    (void)dlclose(gomp_handler_);
    gomp_handler_ = nullptr;
  }
  ms_create_handle_ = nullptr;
}

InferenceLoader &InferenceLoader::Instance() {
  static InferenceLoader inference = InferenceLoader();
  return inference;
}

std::shared_ptr<InferenceBase> InferenceLoader::CreateMindSporeInfer() {
  Status status;
  if (ms_lib_handle_ == nullptr) {
    status = LoadMindSporeModelWrap();
    if (status != SUCCESS) {
      MSI_LOG_EXCEPTION << "Load " << kServingAscendLibName << " failed, error msg: " << status.StatusMessage();
    }
  }
  auto instance = ms_create_handle_();
  if (instance == nullptr) {
    return nullptr;
  }
  return std::shared_ptr<InferenceBase>(instance);
}

std::vector<std::string> SplitString(const std::string &s, const std::string &delimiters = ":") {
  auto pos_left = s.find_first_not_of(delimiters, 0);
  auto pos_right = s.find_first_of(delimiters, pos_left);
  std::vector<std::string> tokens;
  while (pos_left != std::string::npos) {
    if (pos_right == std::string::npos) {
      tokens.push_back(s.substr(pos_left));
      break;
    }
    tokens.push_back(s.substr(pos_left, pos_right - pos_left));
    pos_left = s.find_first_not_of(delimiters, pos_right);
    pos_right = s.find_first_of(delimiters, pos_left);
  }
  return tokens;
}

Status InferenceLoader::LoadMindSporeModelWrap() {
  MSI_LOG_INFO << "Start Initialize MindSpore Model Wrap so";
  std::vector<std::string> gomp_list = {"libgomp.so.1"};
  for (auto &item : gomp_list) {
    gomp_handler_ = dlopen(item.c_str(), RTLD_NOW | RTLD_GLOBAL);
    if (gomp_handler_ != nullptr) {
      MSI_LOG_INFO << "dlopen libgomp so: " << item << " success";
    }
  }
  if (gomp_handler_ == nullptr) {
    MSI_LOG_WARNING << "dlopen libgomp library failed, try dlopen list: " << gomp_list;
  }
  auto get_dlerror = []() -> std::string {
    auto error = dlerror();
    if (error == nullptr) {
      return std::string();
    }
    return error;
  };
  enable_lite_ = ServableContext::Instance()->EnableLite();
  auto ld_lib_path = common::GetEnv("LD_LIBRARY_PATH");
  MSI_LOG_INFO << "Enable lite: " << enable_lite_ << ", LD_LIBRARY_PATH: " << ld_lib_path;
  if (enable_lite_) {
    ms_cxx_lib_handle_ = dlopen(kMindsporeLiteLibName, RTLD_NOW | RTLD_GLOBAL);
    if (ms_cxx_lib_handle_ == nullptr) {
      std::string load_error = get_dlerror();
      std::string so_no_exist_error =
        std::string(kMindsporeLiteLibName) + ": cannot open shared object file: No such file or directory";
      // libmindspore-lite.so exists but dlopen failed
      if (load_error.find(so_no_exist_error) == std::string::npos) {
        return INFER_STATUS_LOG_ERROR(FAILED) << "dlopen libmindspore-lite.so failed, dlopen error: " << load_error;
      }
      return INFER_STATUS_LOG_ERROR(FAILED)
             << "dlopen libmindspore-lite.so failed, if you want to use MindSpore Lite to do the inference, please "
                "append libmindspore-lite.so's path to the LD_LIBRARY_PATH env or put it in the dynamic library "
                "search path"
             << ", dlopen error: " << load_error;
    }
    MSI_LOG_INFO << "Load " << kMindsporeLiteLibName << " successful";
  } else {
    if (!ld_lib_path.empty()) {
      auto ms_search_path_list = SplitString(ld_lib_path, ":");
      MSI_LOG_INFO << "Search " << kMindSporeLibName << " directory: " << ms_search_path_list;
      for (auto &item : ms_search_path_list) {
        auto lib_path = item + "/" + kMindSporeLibName;
        if (!common::DirOrFileExist(lib_path)) {
          continue;
        }
        ms_cxx_lib_handle_ = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
        if (ms_cxx_lib_handle_ == nullptr) {
          return INFER_STATUS_LOG_ERROR(FAILED)
                 << "dlopen libmindspore.so failed, please check whether the MindSpore "
                    "and Ascend/GPU software package versions match"
                 << ", lib path: " << lib_path << ", dlopen error: " << get_dlerror();
        }
        MSI_LOG_INFO << "Load " << kMindSporeLibName << " in " << item << " successful";
        break;
      }
    }
    if (ms_cxx_lib_handle_ == nullptr) {
      return INFER_STATUS_LOG_ERROR(FAILED)
             << "Failed to load libmindspore.so, please pip install the MindSpore whl package for libmindspore.so";
    }
  }
  ms_lib_handle_ = dlopen(kServingAscendLibName, RTLD_NOW | RTLD_GLOBAL);
  if (ms_lib_handle_ == nullptr) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "dlopen failed, please check whether the MindSpore and Serving versions match, lib name: "
           << kServingAscendLibName << ", dlopen error: " << get_dlerror();
  }
  MSI_LOG_INFO << "Load " << kServingAscendLibName << " successful";
  ms_create_handle_ = (CreateInferHandle)dlsym(ms_lib_handle_, "ServingCreateInfer");
  if (ms_create_handle_ == nullptr) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "dlsym ServingCreateInfer failed, lib name: " << kServingAscendLibName
                                          << ", dlsym error: " << get_dlerror();
  }
  return SUCCESS;
}

bool InferenceLoader::GetEnableLite() const { return enable_lite_; }

DeviceType InferenceLoader::GetSupportDeviceType(DeviceType device_type, ModelType model_type) {
  auto mindspore_infer = CreateMindSporeInfer();
  if (mindspore_infer == nullptr) {
    MSI_LOG_ERROR << "Create MindSpore infer failed";
    return kDeviceTypeNotSpecified;
  }
  std::vector<ModelType> check_model_types;
  if (model_type == kUnknownType) {
    check_model_types = {kMindIR, kMindIR_Lite, kOM};
  } else {
    check_model_types = {model_type};
  }
  for (auto &model_type_item : check_model_types) {
    if (device_type == kDeviceTypeNotSpecified) {
      auto device_list = {kDeviceTypeAscend, kDeviceTypeGpu, kDeviceTypeCpu};
      for (auto item : device_list) {
        if (mindspore_infer->CheckModelSupport(item, model_type_item)) {
          return item;
        }
      }
    } else {
      if (mindspore_infer->CheckModelSupport(device_type, model_type_item)) {
        return device_type;
      }
    }
  }
  return kDeviceTypeNotSpecified;
}

bool InferenceLoader::SupportReuseDevice() {
  auto mindspore_infer = CreateMindSporeInfer();
  if (mindspore_infer == nullptr) {
    MSI_LOG_ERROR << "Create MindSpore infer failed";
    return false;
  }
  return mindspore_infer->SupportReuseDevice();
}
}  // namespace mindspore::serving


================================================
FILE: mindspore_serving/ccsrc/worker/inference/inference.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_INFERENCE_H
#define MINDSPORE_SERVING_WORKER_INFERENCE_H

#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "common/serving_common.h"

namespace mindspore {
namespace serving {
using DeviceInfo = std::map<std::string, std::string>;

enum DeviceType {
  kDeviceTypeNotSpecified,
  kDeviceTypeAscend,
  kDeviceTypeGpu,
  kDeviceTypeCpu,
};

enum ModelType : uint32_t {
  kMindIR = 0,
  kAIR = 1,
  kOM = 2,
  kONNX = 3,
  kMindIR_Lite = 4,
  // insert new data type here
  kUnknownType = 0xFFFFFFFF
};

struct MS_API ModelContext {
  int32_t thread_num{-1};  // -1: unspecified
  std::vector<int> thread_affinity_core_list;
  int enable_parallel{-1};  // -1: unspecified, 0: false, 1: true
  std::vector<DeviceInfo> device_list;
  void AppendDeviceInfo(const DeviceInfo &device_info);
  std::string AsString() const;
};

struct TensorInfo {
  size_t size = 0;
  DataType data_type = kMSI_Unknown;
  std::vector<int64_t> shape;
  bool is_no_batch_dim = false;
};

struct TensorInfoOutput {
  TensorInfo tensor_info;
  size_t size_one_batch = 0;
  std::vector<int64_t> shape_one_batch;
};

static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) {
  switch (device_type) {
    case kDeviceTypeAscend:
      stream << "Ascend";
      break;
    case kDeviceTypeGpu:
      stream << "Gpu";
      break;
    case kDeviceTypeCpu:
      stream << "Cpu";
      break;
    case kDeviceTypeNotSpecified:
      stream << "None(Default)";
      break;
    default:
      stream << "[device type: " << static_cast<int>(device_type) << "]";
      break;
  }
  return stream;
}

static inline LogStream &operator<<(LogStream &stream, ModelType model_type) {
  switch (model_type) {
    case kMindIR:
      stream << "MindIR";
      break;
    case kOM:
      stream << "OM";
      break;
    case kONNX:
      stream << "ONNX";
      break;
    case kAIR:
      stream << "AIR";
      break;
    case kMindIR_Lite:
      stream << "MindIR_Lite";
      break;
    case kUnknownType:
    default:
      stream << "[model type: " << static_cast<uint32_t>(model_type) << "]";
      break;
  }
  return stream;
}

class InferenceBase {
 public:
  InferenceBase() = default;
  virtual ~InferenceBase() = default;

  virtual Status LoadModelFromFile(DeviceType device_type, uint32_t device_id,
                                   const std::vector<std::string> &file_name, ModelType model_type,
                                   bool with_batch_dim, const std::vector<int> &without_batch_dim_inputs,
                                   const ModelContext &model_context, const std::string &dec_key,
                                   const std::string &dec_mode, const std::string &config_file, bool enable_lite) = 0;
  virtual Status UnloadModel() = 0;
  virtual Status ExecuteModel(const RequestBase &request, ReplyBase *reply, bool return_result, uint64_t subgraph) = 0;
  virtual Status ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply,
                              bool return_result, uint64_t subgraph) = 0;
  virtual std::vector<TensorInfo> GetInputInfos(uint64_t subgraph) const = 0;
  virtual std::vector<TensorInfo> GetOutputInfos(uint64_t subgraph) const = 0;
  virtual ssize_t GetBatchSize(uint64_t subgraph) const = 0;
  virtual bool CheckModelSupport(DeviceType device_type, ModelType model_type) const = 0;
  virtual uint64_t GetSubGraphNum() const = 0;
  virtual bool SupportReuseDevice() const = 0;
};

class MS_API InferenceLoader {
 public:
InferenceLoader(); ~InferenceLoader(); static InferenceLoader &Instance(); std::shared_ptr CreateMindSporeInfer(); DeviceType GetSupportDeviceType(DeviceType device_type, ModelType model_type); bool SupportReuseDevice(); bool GetEnableLite() const; private: typedef InferenceBase *(*CreateInferHandle)(); void *ms_lib_handle_ = nullptr; void *ms_cxx_lib_handle_ = nullptr; void *gomp_handler_ = nullptr; CreateInferHandle ms_create_handle_ = nullptr; Status LoadMindSporeModelWrap(); bool enable_lite_{false}; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_WORKER_INFERENCE_H ================================================ FILE: mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "worker/inference/mindspore_model_wrap.h" #include #include #include #include #include namespace mindspore { namespace serving { extern "C" { MS_API InferenceBase *ServingCreateInfer() { auto obj = new MindSporeModelWrap(); return dynamic_cast(obj); } } std::mutex MindSporeModelWrap::infer_mutex_; mindspore::DataType TransInferDataType2ApiTypeId(DataType data_type) { const std::map type2id_map{ {serving::kMSI_Unknown, mindspore::DataType::kTypeUnknown}, {serving::kMSI_Bool, mindspore::DataType::kNumberTypeBool}, {serving::kMSI_Int8, mindspore::DataType::kNumberTypeInt8}, {serving::kMSI_Uint8, mindspore::DataType::kNumberTypeUInt8}, {serving::kMSI_Int16, mindspore::DataType::kNumberTypeInt16}, {serving::kMSI_Uint16, mindspore::DataType::kNumberTypeUInt16}, {serving::kMSI_Int32, mindspore::DataType::kNumberTypeInt32}, {serving::kMSI_Uint32, mindspore::DataType::kNumberTypeUInt32}, {serving::kMSI_Int64, mindspore::DataType::kNumberTypeInt64}, {serving::kMSI_Uint64, mindspore::DataType::kNumberTypeUInt64}, {serving::kMSI_Float16, mindspore::DataType::kNumberTypeFloat16}, {serving::kMSI_Float32, mindspore::DataType::kNumberTypeFloat32}, {serving::kMSI_Float64, mindspore::DataType::kNumberTypeFloat64}, }; auto it = type2id_map.find(data_type); if (it == type2id_map.end()) { MSI_LOG_WARNING << "Unsupported MSI data type " << data_type; return mindspore::DataType::kTypeUnknown; } else { return it->second; } } DataType TransTypeId2InferDataType(mindspore::DataType type_id) { const std::map id2type_map{ {mindspore::DataType::kTypeUnknown, kMSI_Unknown}, {mindspore::DataType::kNumberTypeBool, kMSI_Bool}, {mindspore::DataType::kNumberTypeFloat64, kMSI_Float64}, {mindspore::DataType::kNumberTypeInt8, kMSI_Int8}, {mindspore::DataType::kNumberTypeUInt8, kMSI_Uint8}, {mindspore::DataType::kNumberTypeInt16, kMSI_Int16}, {mindspore::DataType::kNumberTypeUInt16, kMSI_Uint16}, {mindspore::DataType::kNumberTypeInt32, kMSI_Int32}, {mindspore::DataType::kNumberTypeUInt32, kMSI_Uint32}, {mindspore::DataType::kNumberTypeInt64, kMSI_Int64}, {mindspore::DataType::kNumberTypeUInt64, kMSI_Uint64}, {mindspore::DataType::kNumberTypeFloat16, kMSI_Float16}, 
{mindspore::DataType::kNumberTypeFloat32, kMSI_Float32}, }; auto it = id2type_map.find(type_id); if (it == id2type_map.end()) { MSI_LOG_WARNING << "Unsupported data id " << static_cast(type_id); return kMSI_Unknown; } else { return it->second; } } Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, const std::vector &file_names, ModelType model_type, bool with_batch_dim, const std::vector &without_batch_dim_inputs, const ModelContext &model_context, const std::string &dec_key, const std::string &dec_mode, const std::string &config_file, bool enable_lite) { char path[PATH_MAX]; std::string current_path = getcwd(path, PATH_MAX); std::string build_dir = current_path + "/models_build_temp/"; (void)mkdir(build_dir.c_str(), S_IRWXU | S_IRWXG); build_dir += "device_" + std::to_string(device_id); (void)mkdir(build_dir.c_str(), S_IRWXU | S_IRWXG); auto error_no = chdir(build_dir.c_str()); if (error_no != 0) { MSI_LOG_WARNING << "Failed to call chdir, target build directory: " << build_dir << ", error no: " << error_no; } Status status; if (enable_lite) { status = LoadLiteModelFromFileInner(device_type, device_id, file_names, model_type, with_batch_dim, without_batch_dim_inputs, model_context, config_file); } else { status = LoadModelFromFileInner(device_type, device_id, file_names, model_type, with_batch_dim, without_batch_dim_inputs, model_context, dec_key, dec_mode, config_file); } error_no = chdir(current_path.c_str()); if (error_no != 0) { MSI_LOG_WARNING << "Failed to call chdir, target directory: " << current_path << ", error no: " << error_no; } return status; } Status MindSporeModelWrap::LoadLiteModelFromFileInner(serving::DeviceType device_type, uint32_t device_id, const std::vector &file_names, ModelType model_type, bool with_batch_dim, const std::vector &without_batch_dim_inputs, const ModelContext &model_context, const std::string &config_file) { auto ms_model_type = GetMsModelType(model_type); if (ms_model_type == mindspore::kUnknownType) { return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model type " << model_type; } if (file_names.size() != 1) { return INFER_STATUS_LOG_ERROR(FAILED) << "Load model from file failed, Multi subgraph is not support when the backend is lite, file names: " << file_names; } const auto &file_name = file_names[0]; auto model = std::make_shared(); try { auto context = TransformModelContext(device_type, device_id, model_context, true); if (!config_file.empty()) { auto load_status = model->LoadConfig(config_file); if (!load_status.IsOk()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Load config file: " << config_file << " failed, error details: " << load_status.ToString(); } } auto status = model->Build(file_name, ms_model_type, context); if (!status.IsOk()) { MSI_LOG_ERROR << "Load model from file failed, model file: " << file_name << ", device_type: '" << device_type << "', device_id: " << device_id << ", model type: " << model_type << ", model context: " << model_context.AsString() << ", build error detail: " << status.ToString(); return Status(FAILED, status.ToString()); } } catch (std::runtime_error &ex) { MSI_LOG_ERROR << "Load model from file failed, model file: " << file_name << ", device_type: '" << device_type << "', device_id: " << device_id << ", model type: " << model_type << ", model context: " << model_context.AsString() << ", build error detail: " << ex.what(); return Status(FAILED, ex.what()); } return SetApiModelInfo(device_type, device_id, {file_name}, model_type, with_batch_dim, 
without_batch_dim_inputs, model_context, {model}); } Status MindSporeModelWrap::LoadModelFromFileInner(serving::DeviceType device_type, uint32_t device_id, const std::vector &file_names, ModelType model_type, bool with_batch_dim, const std::vector &without_batch_dim_inputs, const ModelContext &model_context, const std::string &dec_key, const std::string &dec_mode, const std::string &config_file) { auto ms_model_type = GetMsModelType(model_type); if (ms_model_type == mindspore::kUnknownType) { return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model type " << model_type; } std::vector> models; try { std::vector graphs; mindspore::Key key; if (!dec_key.empty()) { auto rt = memcpy_s(key.key, sizeof(key.key), dec_key.data(), dec_key.size()); if (rt != EOK) { return INFER_STATUS_LOG_ERROR(FAILED) << "Load model from file failed, dec key size " << dec_key.size() << " should less than " << key.max_key_len; } key.len = dec_key.size(); } else { key.len = 0; } mindspore::Status ms_status; if (file_names.size() > 1) { ms_status = mindspore::Serialization::Load(file_names, ms_model_type, &graphs, key, dec_mode); } else { (void)graphs.emplace_back(mindspore::Graph()); ms_status = mindspore::Serialization::Load(file_names[0], ms_model_type, &graphs[0], key, dec_mode); } (void)memset_s(key.key, sizeof(key.key), 0, key.max_key_len); if (!ms_status.IsOk()) { MSI_LOG_ERROR << "Load model from file failed, model file: " << file_names << ", device_type: '" << device_type << "', device_id: " << device_id << ", model type: " << model_type << ", model context: " << model_context.AsString() << ", dec mode: " << dec_mode << ", load error detail: " << ms_status.ToString(); return Status(FAILED, ms_status.ToString()); } if (file_names.size() > 1 && graphs.size() != file_names.size()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Load model from file failed, generate graphs size " << graphs.size() << " should equal to " << file_names.size(); } auto context = TransformModelContext(device_type, device_id, model_context, false); for (size_t i = 0; i < file_names.size(); i++) { auto model = std::make_shared(); if (!config_file.empty()) { auto load_status = model->LoadConfig(config_file); if (!load_status.IsOk()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Load config file: " << config_file << " failed, error details: " << load_status.ToString(); } } mindspore::Status status; status = model->Build(mindspore::GraphCell(graphs[i]), context); if (!status.IsOk()) { MSI_LOG_ERROR << "Load model from file failed, model file: " << file_names[i] << ", device_type: '" << device_type << "', device_id: " << device_id << ", model type: " << model_type << ", model context: " << model_context.AsString() << ", build error detail: " << status.ToString(); return Status(FAILED, status.ToString()); } models.push_back(model); } } catch (std::runtime_error &ex) { MSI_LOG_ERROR << "Load model from file failed, model file: " << file_names << ", device_type: '" << device_type << "', device_id: " << device_id << ", model type: " << model_type << ", model context: " << model_context.AsString() << ", build error detail: " << ex.what(); return Status(FAILED, ex.what()); } auto ret = SetApiModelInfo(device_type, device_id, file_names, model_type, with_batch_dim, without_batch_dim_inputs, model_context, models); if (ret != SUCCESS) { return ret; } return BuildOnPredict(); } Status MindSporeModelWrap::SetApiModelInfo(serving::DeviceType device_type, uint32_t device_id, const std::vector &file_names, ModelType model_type, bool with_batch_dim, const 
std::vector &without_batch_dim_inputs, const ModelContext &model_context, const std::vector> &models) { uint64_t last_batch_size = 0; common_model_info_.device_type = device_type; common_model_info_.device_id = device_id; common_model_info_.with_batch_dim = with_batch_dim; common_model_info_.without_batch_dim_inputs = without_batch_dim_inputs; for (size_t i = 0; i < file_names.size(); i++) { ApiModelInfo api_model_info; api_model_info.model = models[i]; auto st = GetModelInfos(&api_model_info); if (st != SUCCESS) { return st; } MSI_LOG_INFO << "Print model info, model file: '" << file_names[i] << "', subgraph " << i; MSI_LOG_INFO << "Model input infos: count " << api_model_info.input_tensor_infos.size(); for (auto &item : api_model_info.input_tensor_infos) { MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; } MSI_LOG_INFO << "Model output infos: count " << api_model_info.output_tensor_infos.size(); for (auto &item : api_model_info.output_tensor_infos) { MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; } auto status = CalculateBatchSize(&api_model_info); if (status != SUCCESS) { MSI_LOG_ERROR << "Calculate batch size failed, model file: " << file_names[i] << ", subgraph: " << i; return status; } if (last_batch_size != 0 && last_batch_size != common_model_info_.batch_size) { return INFER_STATUS_LOG_ERROR(FAILED) << "Expect batch size to be same, last batch size: " << last_batch_size << ", subgraph " << i << " batch size: " << common_model_info_.batch_size; } last_batch_size = common_model_info_.batch_size; models_.push_back(api_model_info); } MSI_LOG_INFO << "Load model from file success, model file: " << file_names << ", device_type: '" << device_type << "', device_id: " << device_id << ", model type: " << model_type << ", model context: " << model_context.AsString(); return SUCCESS; } Status MindSporeModelWrap::BuildOnPredict() { for (size_t i = 0; i < models_.size(); i++) { auto &inputs_info = models_[i].input_tensor_infos; std::vector request; for (auto &info : inputs_info) { auto tensor = std::make_shared(); tensor->set_data_type(info.data_type); tensor->set_shape(info.shape); tensor->resize_data(info.size); request.push_back(tensor); } std::vector reply; auto ret = ExecuteModel(request, &reply, false, i); if (ret != SUCCESS) { MSI_LOG_ERROR << "Failed to execute model when warmup, subgraph " << i; return ret; } } return SUCCESS; } std::shared_ptr MindSporeModelWrap::TransformAscendModelContext(uint32_t device_id, const DeviceInfo &device_info) { auto context_info = std::make_shared(); context_info->SetDeviceID(device_id); using ContextStrFun = std::function; ContextStrFun set_output_type = [context_info](const std::string &val) { // "FP32", "FP16", "UINT8" if (val == "FP32") { context_info->SetOutputType(mindspore::DataType::kNumberTypeFloat32); } else if (val == "FP16") { context_info->SetOutputType(mindspore::DataType::kNumberTypeFloat16); } else if (val == "UINT8") { context_info->SetOutputType(mindspore::DataType::kNumberTypeUInt8); } else { MSI_LOG_ERROR << "Set model context output type failed, unknown data type " << val; } }; for (auto &item : device_info) { const auto &key = item.first; const auto &value = item.second; if (key == "insert_op_cfg_path") { context_info->SetInsertOpConfigPath(value); } else if (key == "input_format") { context_info->SetInputFormat(value); } else if (key == "input_shape") { context_info->SetInputShape(value); } else if (key == "output_type") { set_output_type(value); } else if (key == 
"precision_mode") { context_info->SetPrecisionMode(value); } else if (key == "op_select_impl_mode") { context_info->SetOpSelectImplMode(value); } else if (key == "fusion_switch_config_path") { context_info->SetFusionSwitchConfigPath(value); } else if (key == "buffer_optimize_mode") { context_info->SetBufferOptimizeMode(value); } } return context_info; } std::shared_ptr MindSporeModelWrap::TransformNvidiaGPUModelContext(uint32_t device_id, const DeviceInfo &device_info) { auto context_info = std::make_shared(); context_info->SetDeviceID(device_id); for (auto &item : device_info) { const auto &key = item.first; const auto &value = item.second; if (key == "precision_mode") { context_info->SetPrecisionMode(value); context_info->SetEnableFP16(value == "fp16"); } } return context_info; } std::shared_ptr MindSporeModelWrap::TransformCPUModelContext(const DeviceInfo &device_info) { auto context_info = std::make_shared(); for (auto &item : device_info) { const auto &key = item.first; const auto &value = item.second; if (key == "precision_mode") { context_info->SetEnableFP16(value == "fp16"); } } return context_info; } std::string MindSporeModelWrap::DeviceTypeToString(serving::DeviceType device_type) { switch (device_type) { case kDeviceTypeGpu: return "gpu"; case kDeviceTypeCpu: return "cpu"; case kDeviceTypeAscend: return "ascend"; case kDeviceTypeNotSpecified: default: return "not_specified"; } } DeviceInfo MindSporeModelWrap::GetDeviceInfo(const std::vector &device_list, serving::DeviceType device_type) { DeviceInfo device_info; for (auto &item : device_list) { if (item.at("device_type") == DeviceTypeToString(device_type)) { device_info = item; break; } } return device_info; } std::shared_ptr MindSporeModelWrap::TransformModelContext(serving::DeviceType device_type, uint32_t device_id, const ModelContext &model_context, bool enable_lite) { auto context = std::make_shared(); if (model_context.thread_num != -1) { context->SetThreadNum(model_context.thread_num); } if (model_context.enable_parallel != -1) { context->SetEnableParallel(model_context.enable_parallel != 0); } if (!model_context.thread_affinity_core_list.empty()) { context->SetThreadAffinity(model_context.thread_affinity_core_list); } std::shared_ptr context_info = nullptr; auto device_info = GetDeviceInfo(model_context.device_list, device_type); if (device_type == kDeviceTypeAscend) { context_info = TransformAscendModelContext(device_id, device_info); } else if (device_type == kDeviceTypeCpu) { context_info = TransformCPUModelContext(device_info); } else if (device_type == kDeviceTypeGpu) { context_info = TransformNvidiaGPUModelContext(device_id, device_info); } if (context_info != nullptr) { context->MutableDeviceInfo().push_back(context_info); } if (enable_lite && device_type != kDeviceTypeCpu) { auto cpu_device_info = GetDeviceInfo(model_context.device_list, kDeviceTypeCpu); context->MutableDeviceInfo().push_back(TransformCPUModelContext(cpu_device_info)); } return context; } Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { MSI_EXCEPTION_IF_NULL(api_model_info); auto model = api_model_info->model; auto get_tensor_info_from_tensor = [](const mindspore::MSTensor &ms_tensor) { serving::TensorInfo tensor_info; tensor_info.shape = ms_tensor.Shape(); tensor_info.data_type = TransTypeId2InferDataType(ms_tensor.DataType()); tensor_info.size = ms_tensor.DataSize(); if (tensor_info.size == 0) { auto &shape = tensor_info.shape; int64_t elements_nums = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); 
if (elements_nums <= 0) { MSI_LOG_ERROR << "Invalid tensor shape " << shape; return serving::TensorInfo(); } tensor_info.size = TensorBase::GetTypeSize(tensor_info.data_type) * static_cast(elements_nums); } return tensor_info; }; { // input infos auto input_infos = model->GetInputs(); for (size_t i = 0; i < input_infos.size(); i++) { auto &info = input_infos[i]; auto tensor_info = get_tensor_info_from_tensor(info); if (tensor_info.data_type == kMSI_Unknown) { return INFER_STATUS_LOG_ERROR(FAILED) << "Unknown input mindspore data type " << static_cast(info.DataType()); } api_model_info->input_tensor_infos.push_back(tensor_info); api_model_info->input_names.push_back(info.Name()); } } { // output infos auto output_infos = model->GetOutputs(); for (auto &info : output_infos) { auto tensor_info = get_tensor_info_from_tensor(info); if (tensor_info.data_type == kMSI_Unknown) { return INFER_STATUS_LOG_ERROR(FAILED) << "Unknown output mindspore data type " << static_cast(info.DataType()); } api_model_info->output_tensor_infos.push_back(tensor_info); api_model_info->output_names.push_back(info.Name()); } } return SUCCESS; } Status MindSporeModelWrap::CalculateBatchSize(ApiModelInfo *api_model_info) { auto &input_infos = api_model_info->input_tensor_infos; auto &output_infos = api_model_info->output_tensor_infos; if (!common_model_info_.with_batch_dim) { common_model_info_.batch_size = 1; for (auto &input : input_infos) { input.is_no_batch_dim = true; } for (auto &output : output_infos) { output.is_no_batch_dim = true; } return SUCCESS; } const auto &list = common_model_info_.without_batch_dim_inputs; uint32_t cur_batch_size = 0; for (size_t i = 0; i < input_infos.size(); i++) { auto &input = input_infos[i]; if (std::find(list.begin(), list.end(), i) != list.end()) { input.is_no_batch_dim = true; continue; } if (input.shape.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "The shape of model input " << i << " cannot be empty, " << "when with_batch_dim is true and without_batch_dim_inputs is " << list; } if (input.shape[0] <= 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "The shape of model input " << i << " is invalid, shape: " << input.shape; } if (cur_batch_size == 0) { cur_batch_size = static_cast(input.shape[0]); continue; } if (input.shape[0] != cur_batch_size) { return INFER_STATUS_LOG_ERROR(FAILED) << "The shape " << input.shape << " of model input " << i << " does not match current batch size " << cur_batch_size; } } for (size_t i = 0; i < output_infos.size(); i++) { auto &output = output_infos[i]; if (output.shape.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "The shape of model output " << i << " cannot be empty"; } if (output.shape[0] <= 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "The shape of model output " << i << " is invalid, shape: " << output.shape; } if (cur_batch_size == 0) { cur_batch_size = static_cast(output.shape[0]); continue; } if (output.shape[0] != cur_batch_size) { return INFER_STATUS_LOG_ERROR(FAILED) << "The shape " << output.shape << " of model output " << i << " does not match current batch size " << cur_batch_size; } } if (cur_batch_size == 0) { cur_batch_size = 1; } common_model_info_.batch_size = cur_batch_size; return SUCCESS; } Status MindSporeModelWrap::UnloadModel() { for (auto &iter : models_) { iter.model = nullptr; } return SUCCESS; } Status MindSporeModelWrap::ExecuteModel(const RequestBase &request, serving::ReplyBase *reply, bool return_result, uint64_t subgraph) { MSI_EXCEPTION_IF_NULL(reply); FuncMakeInBuffer func_in = [&request](size_t 
  FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) {
    auto input_tensor = request[index];
    if (input_tensor == nullptr || input_tensor->data() == nullptr) {
      MSI_LOG_EXCEPTION << "Input tensor data cannot be nullptr, index " << index;
    }
    return mindspore::MSTensor::CreateRefTensor(name, TransInferDataType2ApiTypeId(input_tensor->data_type()),
                                                input_tensor->shape(), const_cast<uint8_t *>(input_tensor->data()),
                                                input_tensor->data_size(), false);
  };
  FuncMakeOutTensor func_out = [&reply](const mindspore::MSTensor &result_tensor, DataType data_type,
                                        const std::vector<int64_t> &shape) {
    if (result_tensor.IsDevice()) {
      MSI_LOG_EXCEPTION << "Can not support device type tensor";
    }
    auto tensor = reply->add();
    MSI_EXCEPTION_IF_NULL(tensor);
    (void)tensor->set_data(result_tensor.Data().get(), result_tensor.DataSize());
    tensor->set_data_type(data_type);
    tensor->set_shape(shape);
  };
  return ExecuteModelCommon(request.size(), func_in, func_out, return_result, subgraph);
}

Status MindSporeModelWrap::ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply,
                                        bool return_result, uint64_t subgraph) {
  if (subgraph >= models_.size()) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Inputs subgraph label error, subgraph label is " << subgraph
                                          << ", total graph number is " << models_.size();
  }
  MSI_EXCEPTION_IF_NULL(reply);
  FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) {
    auto &input_tensor = request[index];
    return mindspore::MSTensor::CreateRefTensor(name, TransInferDataType2ApiTypeId(input_tensor->data_type()),
                                                input_tensor->shape(), const_cast<uint8_t *>(input_tensor->data()),
                                                input_tensor->data_size(), false);
  };
  FuncMakeOutTensor func_out = [&reply](const mindspore::MSTensor &result_tensor, DataType data_type,
                                        const std::vector<int64_t> &shape) {
    if (result_tensor.IsDevice()) {
      MSI_LOG_EXCEPTION << "Can not support device type tensor";
    }
    TensorBasePtr tensor = nullptr;
    // lite backend: copy the output data, the result tensor is reused across predict calls
    if (InferenceLoader::Instance().GetEnableLite()) {
      tensor = std::make_shared<Tensor>(data_type, shape, result_tensor.Data().get(), result_tensor.DataSize());
    } else {
      tensor = std::make_shared<ApiBufferTensorWrap>(data_type, shape, result_tensor);
    }
    reply->push_back(tensor);
  };
  return ExecuteModelCommon(request.size(), func_in, func_out, return_result, subgraph);
}

Status MindSporeModelWrap::ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func,
                                              const FuncMakeOutTensor &out_func, bool return_result,
                                              uint64_t subgraph) {
  if (models_[subgraph].model == nullptr) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Model is not loaded";
  }
  auto &model_info = models_[subgraph];
  auto model = model_info.model;
  auto &input_names = model_info.input_names;
  auto &output_names = model_info.output_names;
  if (input_names.size() != request_size) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Inputs size not match, request inputs size " << request_size
                                          << ", model inputs size " << input_names.size();
  }
  std::vector<mindspore::MSTensor> inputs;
  for (size_t i = 0; i < input_names.size(); i++) {
    auto tensor = in_func(i, input_names[i]);
    if (tensor == nullptr) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to create input " << i << " MSTensor";
    }
    inputs.push_back(*tensor);
    mindspore::MSTensor::DestroyTensorPtr(tensor);
  }
  std::vector<mindspore::MSTensor> outputs;
  mindspore::Status status;
  if (SupportMultiThreads()) {
    status = model->Predict(inputs, &outputs);
  } else {  // vm backend
    std::unique_lock<std::mutex> lock(infer_mutex_);
    status = model->Predict(inputs, &outputs);
  }
  if (!status.IsOk()) {
    MSI_LOG_ERROR << "Predict failed: " << status.ToString();
    return Status(FAILED, "Predict Failed");
  }
  if (outputs.size() != output_names.size()) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Outputs size not match, predict outputs size " << outputs.size()
                                          << ", model outputs size " << output_names.size();
  }
  if (return_result) {
    auto &output_infos = model_info.output_tensor_infos;
    for (size_t i = 0; i < output_names.size(); i++) {
      auto &result_tensor = outputs[i];
      auto &output_info = output_infos[i];
      if (result_tensor.DataSize() != output_info.size) {
        return INFER_STATUS_LOG_ERROR(FAILED)
               << "Get output failed, predict output data size " << result_tensor.DataSize()
               << " not match model info data size " << output_info.size << ", output_name " << output_names[i];
      }
      out_func(result_tensor, output_info.data_type, output_info.shape);
    }
  }
  return SUCCESS;
}

std::vector<serving::TensorInfo> MindSporeModelWrap::GetInputInfos(uint64_t subgraph) const {
  return models_[subgraph].input_tensor_infos;
}

std::vector<serving::TensorInfo> MindSporeModelWrap::GetOutputInfos(uint64_t subgraph) const {
  return models_[subgraph].output_tensor_infos;
}

ssize_t MindSporeModelWrap::GetBatchSize(uint64_t) const { return common_model_info_.batch_size; }

uint64_t MindSporeModelWrap::GetSubGraphNum() const { return models_.size(); }

bool MindSporeModelWrap::SupportReuseDevice() const {
  static bool support_reuse_device = false;
  static bool value_set = false;
  if (!value_set) {
    value_set = true;
    auto is_device_910 = mindspore::Model::CheckModelSupport(mindspore::kAscend910, mindspore::kMindIR);
    support_reuse_device = !is_device_910;
  }
  return support_reuse_device;
}

bool MindSporeModelWrap::SupportMultiThreads() const {
  static bool support_multi_thread = false;
  static bool value_set = false;
  if (!value_set) {
    value_set = true;
    if (InferenceLoader::Instance().GetEnableLite()) {
      support_multi_thread = true;
    } else if (mindspore::Model::CheckModelSupport(mindspore::kAscend910, mindspore::kMindIR)) {
      support_multi_thread = false;
    } else if (mindspore::Model::CheckModelSupport(mindspore::kGPU, mindspore::kMindIR)) {
      support_multi_thread = false;
    } else {
      support_multi_thread = true;
    }
  }
  return support_multi_thread;
}

bool MindSporeModelWrap::CheckModelSupport(DeviceType device_type, ModelType model_type) const {
  auto ms_device_type = GetMsDeviceType(device_type);
  if (ms_device_type == mindspore::kInvalidDeviceType) {
    return false;
  }
  auto ms_model_type = GetMsModelType(model_type);
  if (ms_model_type == mindspore::kUnknownType) {
    return false;
  }
  return mindspore::Model::CheckModelSupport(ms_device_type, ms_model_type);
}

mindspore::ModelType MindSporeModelWrap::GetMsModelType(serving::ModelType model_type) {
  mindspore::ModelType ms_model_type;
  switch (model_type) {
    case kMindIR:
      ms_model_type = mindspore::kMindIR;
      break;
    case kMindIR_Lite:
      ms_model_type = mindspore::kMindIR_Lite;
      break;
    case kAIR:
      ms_model_type = mindspore::kAIR;
      break;
    case kOM:
      ms_model_type = mindspore::kOM;
      break;
    case kONNX:
      ms_model_type = mindspore::kONNX;
      break;
    case kUnknownType:
    default:
      ms_model_type = mindspore::kUnknownType;
  }
  return ms_model_type;
}

mindspore::DeviceType MindSporeModelWrap::GetMsDeviceType(serving::DeviceType device_type) {
  mindspore::DeviceType ms_device_type = mindspore::DeviceType::kInvalidDeviceType;
  switch (device_type) {
    case kDeviceTypeAscend:
      ms_device_type = mindspore::DeviceType::kAscend;
      break;
    case kDeviceTypeGpu:
      ms_device_type = mindspore::DeviceType::kGPU;
      break;
    case kDeviceTypeCpu:
      ms_device_type = mindspore::DeviceType::kCPU;
      break;
    case kDeviceTypeNotSpecified:
    default:
      break;
  }
  return ms_device_type;
}

ApiBufferTensorWrap::ApiBufferTensorWrap() = default;
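// A minimal sketch of the two-callback marshalling pattern that
// ExecuteModelCommon implements above (all names here are hypothetical and the
// function is illustrative only): each ExecuteModel overload injects how
// inputs are built and how outputs are published, so a single predict loop can
// serve both the proto-based path and the in-process tensor path.
namespace {
using ExampleMakeIn = std::function<int(size_t index, const std::string &name)>;
using ExampleMakeOut = std::function<void(int result)>;

[[maybe_unused]] void ExamplePredictLoop(const std::vector<std::string> &input_names, const ExampleMakeIn &in_func,
                                         const ExampleMakeOut &out_func) {
  std::vector<int> inputs;
  for (size_t i = 0; i < input_names.size(); i++) {
    inputs.push_back(in_func(i, input_names[i]));  // marshal request slot i into a backend input
  }
  for (auto &input : inputs) {
    out_func(input);  // stand-in for running Predict() and publishing each output via out_func
  }
}
}  // namespace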
ApiBufferTensorWrap::ApiBufferTensorWrap(DataType type, const std::vector<int64_t> &shape,
                                         const mindspore::MSTensor &tensor)
    : type_(type), shape_(shape), tensor_(tensor) {}

ApiBufferTensorWrap::~ApiBufferTensorWrap() = default;
}  // namespace serving
}  // namespace mindspore

================================================
FILE: mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_MODEL_WRAP_H
#define MINDSPORE_SERVING_WORKER_MODEL_WRAP_H

#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "common/serving_common.h"
#include "worker/inference/inference.h"
#include "include/api/model.h"
#include "include/api/types.h"
#include "include/api/data_type.h"
#include "include/api/serialization.h"
#include "include/api/context.h"

namespace mindspore {
namespace serving {
struct ApiModelInfo {
  std::vector<std::string> input_names;
  std::vector<serving::TensorInfo> input_tensor_infos;
  std::vector<std::string> output_names;
  std::vector<serving::TensorInfo> output_tensor_infos;
  std::shared_ptr<mindspore::Model> model = nullptr;
};

struct ApiCommonModelInfo {
  uint32_t batch_size = 0;
  serving::DeviceType device_type;
  uint32_t device_id = 0;
  bool with_batch_dim = false;
  std::vector<size_t> without_batch_dim_inputs;
};
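// A DeviceInfo behaves as a flat map from option name to option string; for a
// GPU device an entry might carry {{"device_type", "gpu"},
// {"precision_mode", "fp16"}} (illustrative values, matching the keys read in
// GetDeviceInfo and TransformNvidiaGPUModelContext).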
class MindSporeModelWrap : public InferenceBase {
 public:
  MindSporeModelWrap() = default;
  ~MindSporeModelWrap() = default;

  Status LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id,
                           const std::vector<std::string> &file_names, ModelType model_type, bool with_batch_dim,
                           const std::vector<size_t> &without_batch_dim_inputs, const ModelContext &model_context,
                           const std::string &dec_key, const std::string &dec_mode, const std::string &config_file,
                           bool enable_lite) override;
  Status UnloadModel() override;
  Status ExecuteModel(const RequestBase &request, ReplyBase *reply, bool return_result, uint64_t subgraph) override;
  Status ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply,
                      bool return_result, uint64_t subgraph) override;
  std::vector<TensorInfo> GetInputInfos(uint64_t subgraph) const override;
  std::vector<TensorInfo> GetOutputInfos(uint64_t subgraph) const override;
  ssize_t GetBatchSize(uint64_t subgraph) const override;
  bool CheckModelSupport(DeviceType device_type, ModelType model_type) const override;
  uint64_t GetSubGraphNum() const override;
  bool SupportReuseDevice() const override;
  bool SupportMultiThreads() const;

 private:
  ApiCommonModelInfo common_model_info_;
  std::vector<ApiModelInfo> models_;
  static std::mutex infer_mutex_;

  using FuncMakeInBuffer = std::function<mindspore::MSTensor *(size_t index, const std::string &name)>;
  using FuncMakeOutTensor =
    std::function<void(const mindspore::MSTensor &result_tensor, DataType data_type,
                       const std::vector<int64_t> &shape)>;

  Status ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func, const FuncMakeOutTensor &out_func,
                            bool return_result, uint64_t subgraph);
  Status GetModelInfos(ApiModelInfo *model_info);
  Status SetApiModelInfo(serving::DeviceType device_type, uint32_t device_id,
                         const std::vector<std::string> &file_names, ModelType model_type, bool with_batch_dim,
                         const std::vector<size_t> &without_batch_dim_inputs, const ModelContext &model_context,
                         const std::vector<std::shared_ptr<mindspore::Model>> &models);
  Status LoadLiteModelFromFileInner(serving::DeviceType device_type, uint32_t device_id,
                                    const std::vector<std::string> &file_names, ModelType model_type,
                                    bool with_batch_dim, const std::vector<size_t> &without_batch_dim_inputs,
                                    const ModelContext &model_context, const std::string &config_file);
  Status LoadModelFromFileInner(serving::DeviceType device_type, uint32_t device_id,
                                const std::vector<std::string> &file_names, ModelType model_type, bool with_batch_dim,
                                const std::vector<size_t> &without_batch_dim_inputs,
                                const ModelContext &model_context, const std::string &dec_key,
                                const std::string &dec_mode, const std::string &config_file);
  std::shared_ptr<mindspore::Context> TransformModelContext(serving::DeviceType device_type, uint32_t device_id,
                                                            const ModelContext &model_context, bool enable_lite);
  std::shared_ptr<mindspore::DeviceInfoContext> TransformAscendModelContext(uint32_t device_id,
                                                                            const DeviceInfo &device_info);
  std::shared_ptr<mindspore::DeviceInfoContext> TransformNvidiaGPUModelContext(uint32_t device_id,
                                                                               const DeviceInfo &device_info);
  std::shared_ptr<mindspore::DeviceInfoContext> TransformCPUModelContext(const DeviceInfo &device_info);
  DeviceInfo GetDeviceInfo(const std::vector<DeviceInfo> &device_list, serving::DeviceType device_type);
  Status BuildOnPredict();
  Status CalculateBatchSize(ApiModelInfo *api_model_info);
  static mindspore::ModelType GetMsModelType(serving::ModelType model_type);
  static mindspore::DeviceType GetMsDeviceType(serving::DeviceType device_type);
  static std::string DeviceTypeToString(serving::DeviceType device_type);
};

class ApiBufferTensorWrap : public TensorBase {
 public:
  ApiBufferTensorWrap();
  ApiBufferTensorWrap(DataType type, const std::vector<int64_t> &shape, const mindspore::MSTensor &buffer);
  ~ApiBufferTensorWrap() override;

  void set_data_type(DataType type) override { type_ = type; }
  DataType data_type() const override { return type_; }
  void set_shape(const std::vector<int64_t> &shape) override { shape_ = shape; }
  std::vector<int64_t> shape() const override { return shape_; }
  const uint8_t *data() const override { return static_cast<const uint8_t *>(tensor_.Data().get()); }
  size_t data_size() const override { return tensor_.DataSize(); }
  bool resize_data(size_t) override { MSI_LOG_EXCEPTION << "ApiBufferTensorWrap not support resize data"; }
  uint8_t *mutable_data() override { return static_cast<uint8_t *>(tensor_.MutableData()); }

  // For kMSI_String and kMSI_Bytes
  void clear_bytes_data() override { MSI_LOG_EXCEPTION << "Not support for mindspore::Buffer Tensor"; }
  void add_bytes_data(const uint8_t *, size_t) override {
    MSI_LOG_EXCEPTION << "Not support for mindspore::MSTensor Tensor";
  }
  size_t bytes_data_size() const override { MSI_LOG_EXCEPTION << "Not support for mindspore::Buffer Tensor"; }
  void get_bytes_data(size_t, const uint8_t **, size_t *) const override {
    MSI_LOG_EXCEPTION << "Not support for mindspore::MSTensor Tensor";
  }

 private:
  DataType type_ = kMSI_Unknown;
  std::vector<int64_t> shape_;
  mindspore::MSTensor tensor_;
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_WORKER_MODEL_WRAP_H

================================================
FILE: mindspore_serving/ccsrc/worker/local_servable/local_model_loader.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "worker/local_servable/local_model_loader.h" #include #include #include "common/tensor.h" #include "worker/context.h" #include "worker/servable_register.h" namespace mindspore::serving { LocalModelLoader::~LocalModelLoader() noexcept { Clear(); } uint64_t LocalModelLoader::GetGraphNum() const { if (!model_session_) { MSI_LOG_EXCEPTION << "Model '" << GetModelKey() << "' has not been loaded"; } return graph_num_; } Status LocalModelLoader::Predict(const std::vector &input, std::vector *output, uint64_t subgraph) { if (!model_session_) { MSI_LOG_EXCEPTION << "Model '" << GetModelKey() << "' has not been loaded"; } return model_session_->ExecuteModel(input, output, true, subgraph); } std::vector LocalModelLoader::GetInputInfos(uint64_t subgraph) const { if (!model_session_) { MSI_LOG_EXCEPTION << "Model '" << GetModelKey() << "' has not been loaded"; } return model_session_->GetInputInfos(subgraph); } std::vector LocalModelLoader::GetOutputInfos(uint64_t subgraph) const { if (!model_session_) { MSI_LOG_EXCEPTION << "Model '" << GetModelKey() << "' has not been loaded"; } return model_session_->GetOutputInfos(subgraph); } uint64_t LocalModelLoader::GetBatchSize() const { if (!model_session_) { MSI_LOG_EXCEPTION << "Model '" << GetModelKey() << "' has not been loaded"; } auto batch_size = model_session_->GetBatchSize(0); if (batch_size < 0) { MSI_LOG_EXCEPTION << "Invalid batch size " << batch_size << ", model: '" << GetModelKey() << "'"; } return static_cast(batch_size); } Status LocalModelLoader::LoadModel(const std::string &servable_directory, const std::string &servable_name, uint64_t version_number, const ModelMeta &model_meta, const std::string &dec_key, const std::string &dec_mode) { if (model_loaded_) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model has loaded"; } base_spec_.servable_directory = servable_directory; base_spec_.servable_name = servable_name; base_spec_.version_number = version_number; model_meta_ = model_meta; Status status; const ServableSignature &signature = ServableRegister::Instance().GetServableSignature(); if (signature.servable_name != servable_name) { return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; } if (signature.servable_type != kServableTypeLocal) { return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' is not registered as local servable"; } status = InitDevice(model_meta.local_meta.model_format); if (status != SUCCESS) { MSI_LOG_ERROR << "Init env failed"; return status; } status = LoadModel(version_number, dec_key, dec_mode); if (status != SUCCESS) { return status; } model_loaded_ = true; return SUCCESS; } Status LocalModelLoader::InitDevice(ModelType model_type) { auto context = ServableContext::Instance(); auto device_type = context->GetDeviceType(); auto lite_backend = InferenceLoader::Instance().GetEnableLite(); auto support_device_type = InferenceLoader::Instance().GetSupportDeviceType(device_type, model_type); if (support_device_type == kDeviceTypeNotSpecified || (lite_backend && model_type != kMindIR_Lite && model_type != kMindIR)) { std::string inference_package 
= lite_backend ? "MindSpore Lite" : "MindSpore"; return INFER_STATUS_LOG_ERROR(FAILED) << "Not support device type " << device_type << " and model type " << model_type << ". Current inference backend: " << inference_package << ". When the inference backend is MindSpore, Ascend 910 and GPU supports MindIR " << "model. When the inference backend is MindSpore Lite, " << "Ascend 310/310P, GPU and CPU support MindIR and MindIR_Lite model converted by Lite converter tool."; } context->SetDeviceType(support_device_type); return SUCCESS; } Status LocalModelLoader::LoadModel(uint64_t version_number, const std::string &dec_key, const std::string &dec_mode) { const auto &model_meta = model_meta_; auto context = ServableContext::Instance(); std::string model_dir = base_spec_.servable_directory + "/" + base_spec_.servable_name + "/" + std::to_string(version_number); if (!common::DirOrFileExist(model_dir)) { return INFER_STATUS_LOG_ERROR(FAILED) << "Start servable failed: There is no specified version directory of models, specified version number: " << version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '" << base_spec_.servable_name << "'"; } const auto &common_meta = model_meta.common_meta; const auto &local_meta = model_meta.local_meta; std::vector model_file_names; for (auto &file : local_meta.model_files) { std::string model_file_name = model_dir + "/" + file; model_file_names.push_back(model_file_name); } auto session = InferenceLoader::Instance().CreateMindSporeInfer(); if (session == nullptr) { return INFER_STATUS_LOG_ERROR(FAILED) << "Create MindSpore infer failed"; } std::string config_file_path; if (!local_meta.config_file.empty()) { if (local_meta.config_file[0] == '/') { config_file_path = local_meta.config_file; } else { config_file_path = base_spec_.servable_directory + "/" + base_spec_.servable_name + "/" + local_meta.config_file; } } auto enable_lite = InferenceLoader::Instance().GetEnableLite(); Status status = session->LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_names, local_meta.model_format, common_meta.with_batch_dim, common_meta.without_batch_dim_inputs, model_meta.local_meta.model_context, dec_key, dec_mode, config_file_path, enable_lite); if (status != SUCCESS) { return INFER_STATUS_LOG_ERROR(FAILED) << "Load model failed, servable directory: '" << base_spec_.servable_directory << "', servable name: '" << base_spec_.servable_name << "', model file: '" << local_meta.model_files << "', version number " << version_number << ",model context: " << local_meta.model_context.AsString() << ", load error details: " << status.StatusMessage(); } model_session_ = session; graph_num_ = model_file_names.size(); MSI_LOG_INFO << "Load model success, servable directory: '" << base_spec_.servable_directory << "', servable name: '" << base_spec_.servable_name << "', model file: '" << local_meta.model_files << "', version number " << version_number << ", context " << local_meta.model_context.AsString(); return SUCCESS; } void LocalModelLoader::Clear() { if (model_session_ != nullptr) { (void)model_session_->UnloadModel(); model_session_ = nullptr; } model_loaded_ = false; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/worker/local_servable/local_model_loader.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file 
except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H #define MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H #include #include #include #include #include "common/serving_common.h" #include "common/instance.h" #include "common/servable.h" #include "mindspore_serving/ccsrc/worker/model_loader_base.h" #include "worker/inference/inference.h" namespace mindspore::serving { class MS_API LocalModelLoader final : public DirectModelLoaderBase { public: LocalModelLoader() = default; ~LocalModelLoader() noexcept override; Status Predict(const std::vector &input, std::vector *output, uint64_t subgraph) override; std::vector GetInputInfos(uint64_t subgraph) const override; std::vector GetOutputInfos(uint64_t subgraph) const override; uint64_t GetBatchSize() const override; uint64_t GetGraphNum() const override; Status LoadModel(const std::string &servable_directory, const std::string &servable_name, uint64_t version_number, const ModelMeta &model_meta, const std::string &dec_key, const std::string &dec_mode); Status InitDevice(ModelType model_type); void Clear() override; std::string GetModelKey() const { return model_meta_.common_meta.model_key; } private: ServableLoadSpec base_spec_; ModelMeta model_meta_; uint64_t graph_num_ = 0; std::shared_ptr model_session_ = nullptr; bool model_loaded_ = false; Status LoadModel(uint64_t version, const std::string &dec_key, const std::string &dec_mode); }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H ================================================ FILE: mindspore_serving/ccsrc/worker/model_loader_base.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "worker/model_loader_base.h" #include "common/buffer_tensor.h" namespace mindspore::serving { Status DirectModelLoaderBase::Predict(const std::vector &inputs, std::vector *outputs, uint64_t subgraph) { MSI_EXCEPTION_IF_NULL(outputs); if (subgraph >= model_info_.sub_graph_infos.size()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid input subgraph index " << subgraph << ", model info: " << model_key_ << ", subgraph count: " << model_info_.sub_graph_infos.size(); } Status status; std::vector predict_outputs; auto &subgraph_info = model_info_.sub_graph_infos[subgraph]; status = PrePredict(subgraph_info, model_info_.batch_size, inputs); if (status != SUCCESS) { MSI_LOG_ERROR << "Call Pre Predict failed, model info " << model_key_; return status; } status = Predict(subgraph_info.input_buffers, &predict_outputs, subgraph); if (status != SUCCESS) { MSI_LOG_ERROR << "Predict failed, model info " << model_key_; return status; } status = PostPredict(subgraph_info, model_info_.batch_size, inputs, predict_outputs, outputs); if (status != SUCCESS) { MSI_LOG_ERROR << "Call Post Predict failed, model info " << model_key_; return status; } return SUCCESS; } Status DirectModelLoaderBase::PrePredict(const ModelExecutorSubgraphInfo &subgraph_info, uint64_t model_batch_size, const std::vector &instances) { auto input_batch_size = instances.size(); if (input_batch_size == 0 || input_batch_size > model_batch_size) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Invalid input batch size " << input_batch_size << ", model batch size " << model_batch_size; } auto &input_infos = subgraph_info.input_infos; auto &input_buffers = subgraph_info.input_buffers; for (size_t i = 0; i < input_infos.size(); i++) { auto &tensor = input_buffers[i]; auto data_size = tensor->data_size(); auto dst_buffer = reinterpret_cast(tensor->mutable_data()); if (input_infos[i].is_no_batch_dim) { if (data_size != instances[0][i]->data_size()) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Input " << i << " data size " << instances[0][i]->data_size() << "does not match size " << data_size << " defined in model"; } (void)memcpy_s(dst_buffer, data_size, instances[0][i]->data(), data_size); continue; } auto item_size = data_size / model_batch_size; for (size_t k = 0; k < input_batch_size; k++) { if (i >= instances[k].size()) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << " Batch index " << k << " does not have input " << i; } if (item_size != instances[k][i]->data_size()) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Input " << i << " Batch index " << k << " input data size " << instances[k][i]->data_size() << "does not match size " << item_size << " defined in model"; } (void)memcpy_s(dst_buffer + k * item_size, data_size - k * item_size, instances[k][i]->data(), item_size); } for (size_t k = input_batch_size; k < model_batch_size; k++) { (void)memcpy_s(dst_buffer + k * item_size, data_size - k * item_size, instances[0][i]->data(), item_size); } } return SUCCESS; } Status DirectModelLoaderBase::PostPredict(const ModelExecutorSubgraphInfo &subgraph_info, uint64_t model_batch_size, const std::vector &instances, const std::vector &predict_result, std::vector *instance_result) { auto input_batch_size = instances.size(); if (input_batch_size == 0 || input_batch_size > model_batch_size) { MSI_LOG_ERROR << "Input batch size " << input_batch_size << " invalid, model batch size " << model_batch_size; return SYSTEM_ERROR; } if (predict_result.size() != subgraph_info.output_infos.size()) { MSI_LOG_ERROR << "Output result count " 
<< predict_result.size() << " not equal to outputs count " << subgraph_info.output_infos.size(); return SYSTEM_ERROR; } std::vector results_data(input_batch_size); auto &output = subgraph_info.output_infos; for (size_t i = 0; i < predict_result.size(); i++) { auto &item = predict_result[i]; auto &output_info = output[i]; if (item->data_size() != output_info.tensor_info.size) { MSI_LOG_ERROR << "Output result " << i << " data size " << item->data_size() << " not equal to size " << output_info.tensor_info.size << " in output_infos_ "; return SYSTEM_ERROR; } auto item_size = output_info.size_one_batch; auto shape = output_info.shape_one_batch; auto data_type = output_info.tensor_info.data_type; auto src_buffer = const_cast(item->data()); for (size_t k = 0; k < input_batch_size; k++) { auto tensor = std::make_shared(item, data_type, shape, src_buffer + item_size * k, item_size, true); results_data[k].data.push_back(tensor); } } *instance_result = results_data; return SUCCESS; } Status DirectModelLoaderBase::AfterLoadModel() { InitModelExecuteInfo(); return SUCCESS; } void DirectModelLoaderBase::InitModelExecuteInfo() { auto graph_num = GetGraphNum(); model_info_.sub_graph_infos.resize(graph_num); model_info_.batch_size = GetBatchSize(); for (uint64_t i = 0; i < graph_num; i++) { auto input_infos = GetInputInfos(i); auto output_infos = GetOutputInfos(i); auto &subgraph_info = model_info_.sub_graph_infos[i]; subgraph_info.input_infos = input_infos; for (auto &item : output_infos) { TensorInfoOutput info; info.tensor_info = item; if (item.is_no_batch_dim) { info.shape_one_batch = item.shape; info.size_one_batch = item.size; } else { info.shape_one_batch = item.shape; (void)info.shape_one_batch.erase(info.shape_one_batch.begin()); // the batch size has been checked in WorkerExecutor info.size_one_batch = item.size / model_info_.batch_size; } subgraph_info.output_infos.push_back(info); } // init input buffer subgraph_info.input_buffers.clear(); for (auto &input_info : subgraph_info.input_infos) { auto tensor = std::make_shared(); tensor->set_data_type(input_info.data_type); tensor->set_shape(input_info.shape); (void)tensor->resize_data(input_info.size); subgraph_info.input_buffers.push_back(tensor); } } } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/worker/model_loader_base.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H #define MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H #include #include #include #include #include "common/serving_common.h" #include "common/instance_data.h" #include "common/servable.h" #include "worker/inference/inference.h" namespace mindspore::serving { class ModelLoaderBase { public: ModelLoaderBase() = default; virtual ~ModelLoaderBase() = default; virtual std::vector GetInputInfos(uint64_t subgraph) const = 0; virtual std::vector GetOutputInfos(uint64_t subgraph) const = 0; virtual uint64_t GetBatchSize() const = 0; virtual uint64_t GetGraphNum() const = 0; virtual void Clear() = 0; virtual Status Predict(const std::vector &inputs, std::vector *outputs, uint64_t subgraph) = 0; virtual Status AfterLoadModel() = 0; virtual bool OwnDevice() const = 0; }; struct ModelExecutorSubgraphInfo { std::vector input_infos; std::vector output_infos; std::vector input_buffers; }; struct ModelExecutorInfo { std::vector sub_graph_infos; uint64_t batch_size = 0; }; class MS_API DirectModelLoaderBase : public ModelLoaderBase { public: virtual Status Predict(const std::vector &input, std::vector *output, uint64_t subgraph) = 0; Status Predict(const std::vector &inputs, std::vector *outputs, uint64_t subgraph) override; Status AfterLoadModel() override; bool OwnDevice() const override { return true; } private: std::string model_key_; ModelExecutorInfo model_info_; void InitModelExecuteInfo(); Status PrePredict(const ModelExecutorSubgraphInfo &subgraph_info, uint64_t model_batch_size, const std::vector &instances); Status PostPredict(const ModelExecutorSubgraphInfo &subgraph_info, uint64_t model_batch_size, const std::vector &instances, const std::vector &predict_result, std::vector *instance_result); }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H ================================================ FILE: mindspore_serving/ccsrc/worker/notfiy_master/base_notify.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_H #define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_H #include #include "common/serving_common.h" #include "common/servable.h" namespace mindspore { namespace serving { class MS_API BaseNotifyMaster { public: BaseNotifyMaster() = default; virtual ~BaseNotifyMaster() = default; virtual Status Register(const WorkerRegSpec &worker_spec) = 0; virtual Status Unregister() = 0; }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_H ================================================ FILE: mindspore_serving/ccsrc/worker/notfiy_master/grpc_notify.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "worker/notfiy_master/grpc_notify.h" #include #include #include #include #include #include "common/grpc_server.h" #include "worker/servable_register.h" #include "common/shared_memory.h" #include "common/proto_tensor.h" namespace mindspore { namespace serving { GrpcNotifyMaster::GrpcNotifyMaster(const std::string &master_address, const std::string &worker_address) : master_address_(master_address), worker_address_(worker_address) { auto channel = GrpcServer::CreateChannel(master_address_); stub_ = proto::MSMaster::NewStub(channel); } GrpcNotifyMaster::~GrpcNotifyMaster() = default; Status GrpcNotifyMaster::Register(const WorkerRegSpec &worker_spec) { proto::RegisterRequest request; GrpcTensorHelper::ConvertWorkerSpec(worker_spec, &request); MSI_LOG(INFO) << "Register to " << master_address_; proto::RegisterReply reply; grpc::ClientContext context; const int32_t TIME_OUT = 1; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); context.set_deadline(deadline); grpc::Status status = stub_->Register(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Register SUCCESS "; is_running_ = true; return SUCCESS; } return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register failed, Grpc message: " << status.error_code() << ", " << status.error_message(); } Status GrpcNotifyMaster::Unregister() { if (!is_running_) { return SUCCESS; } is_running_ = false; proto::ExitRequest request; request.set_address(worker_address_); MSI_LOG(INFO) << "Unregister to " << master_address_; proto::ExitReply reply; grpc::ClientContext context; const int32_t TIME_OUT = 1; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); context.set_deadline(deadline); grpc::Status status = stub_->Exit(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Exit SUCCESS "; return SUCCESS; } return INFER_STATUS_LOG_WARNING(FAILED) << "Exit Failed, master may have exited, Grpc message: " << status.error_code() << ", " << status.error_message(); } Status GrpcNotifyMaster::NotifyFailed(const std::string &master_address, const std::string &error_msg) { proto::NotifyFailedRequest request; request.set_worker_pid(getpid()); request.set_error_msg(error_msg); auto channel = GrpcServer::CreateChannel(master_address); auto stub = proto::MSMaster::NewStub(channel); proto::NotifyFailedReply reply; grpc::ClientContext context; grpc::Status status = stub->NotifyFailed(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Success to notify master " << master_address << " error message of worker: " << error_msg; return SUCCESS; } MSI_LOG_WARNING << "Failed to notify master " << master_address << " error message of worker: " << error_msg << ", grpc error: " << status.error_message(); return FAILED; } Status GrpcNotifyMaster::GetModelInfos(const std::string &master_address, const std::string &servable_name, uint32_t version_number, proto::GetModelInfoReply *reply) { proto::GetModelInfoRequest request; request.set_servable_name(servable_name); request.set_version_number(version_number); auto 
channel = GrpcServer::CreateChannel(master_address); auto stub = proto::MSMaster::NewStub(channel); grpc::ClientContext context; grpc::Status grpc_status = stub->GetModelInfo(&context, request, reply); if (!grpc_status.ok()) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Get model infos failed, master address:" << master_address << ", Grpc message: " << grpc_status.error_code() << ", " << grpc_status.error_message(); } return SUCCESS; } Status GrpcNotifyMaster::CreateRequestShmInstance(const RemoteCallModelContext &model_context, const InstanceData &instance, proto::Instance *proto_instance, std::vector *alloc_shm_request) { Status status; auto &memory_instance = SharedMemoryAllocator::Instance(); auto &proto_items = *(proto_instance->mutable_items()); for (size_t i = 0; i < instance.size(); i++) { auto &input = instance[i]; auto &memory_key = model_context.request_memory[i]; SharedMemoryItem memory_item; status = memory_instance.AllocMemoryItem(memory_key, &memory_item); if (status != SUCCESS) { MSI_LOG_ERROR << "Alloc request memory failed, memory: " << memory_key; return status; } alloc_shm_request->push_back(memory_item); auto &proto_tensor = proto_items["x" + std::to_string(i)]; // input: x0, x1, x2,... ProtoTensor tensor(&proto_tensor); tensor.set_data_type(input->data_type()); tensor.set_shape(input->shape()); auto proto_shm_data = proto_tensor.mutable_shm_data(); proto_shm_data->set_memory_key(memory_item.memory_key); proto_shm_data->set_bytes_size(memory_item.bytes_size); proto_shm_data->set_data_size(memory_item.size); proto_shm_data->set_data_offset(memory_item.offset); auto ret = memcpy_s(memory_item.offset_address, memory_item.size, input->data(), input->data_size()); if (ret != EOK) { return INFER_STATUS_LOG_ERROR(FAILED) << "Copy tensor to shared memory failed, dst size: " << memory_item.size << ", src size: " << input->data_size(); } } return SUCCESS; } Status GrpcNotifyMaster::CreateResultShmInstance(const RemoteCallModelContext &model_context, ResultInstance *result_instance, proto::Instance *proto_instance) { Status status; auto &memory_instance = SharedMemoryAllocator::Instance(); auto &proto_reply_items = *(proto_instance->mutable_output_buffers()); for (size_t i = 0; i < model_context.output_infos.size(); i++) { auto &output_info = model_context.output_infos[i]; auto &memory_key = model_context.reply_memory[i]; SharedMemoryItem memory_item; status = memory_instance.AllocMemoryItem(memory_key, &memory_item); if (status != SUCCESS) { MSI_LOG_ERROR << "Alloc request memory failed, memory: " << memory_key; return status; } auto &proto_output = proto_reply_items["y" + std::to_string(i)]; proto_output.set_memory_key(memory_item.memory_key); proto_output.set_bytes_size(memory_item.bytes_size); proto_output.set_data_size(memory_item.size); proto_output.set_data_offset(memory_item.offset); auto result_tensor = std::make_shared(output_info.tensor_info.data_type, output_info.shape_one_batch, memory_item); result_instance->data.push_back(result_tensor); } return SUCCESS; } Status GrpcNotifyMaster::CallModelInner(const RemoteCallModelContext &model_context, const std::vector &request, std::vector *reply, std::vector *alloc_shm_request) { proto::PredictRequest proto_request; auto servable_spec = proto_request.mutable_servable_spec(); servable_spec->set_name(ServableRegister::Instance().GetServableSignature().servable_name); servable_spec->set_method_name( ServableRegister::GetCallModelMethodName(model_context.model_name, model_context.subgraph)); 
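  // The call-model method name has the form "@call_<model_key>_<subgraph>"
  // (see ServableRegister::GetCallModelMethodName below), e.g. "@call_add_0"
  // for a model key "add" and subgraph 0, routing this request to the hidden
  // per-model method registered at startup.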
  servable_spec->set_version_number(model_context.version_number);
  auto proto_instances = proto_request.mutable_instances();
  Status status;
  std::vector<ResultInstance> result_instances;
  for (auto &instance : request) {
    auto proto_instance = proto_instances->Add();
    status = CreateRequestShmInstance(model_context, instance, proto_instance, alloc_shm_request);
    if (status != SUCCESS) {
      return status;
    }
    ResultInstance result_instance;
    status = CreateResultShmInstance(model_context, &result_instance, proto_instance);
    if (status != SUCCESS) {
      return status;
    }
    result_instances.push_back(result_instance);
  }
  proto::PredictReply proto_reply;
  MSI_TIME_STAMP_START(CallModel)
  grpc::ClientContext context;
  grpc::Status grpc_status = stub_->CallModel(&context, proto_request, &proto_reply);
  MSI_TIME_STAMP_END_EXTRA(CallModel, "Request count " + std::to_string(request.size()))
  if (!grpc_status.ok()) {
    return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR)
           << "Remote call model failed, master address:" << master_address_
           << ", Grpc message: " << grpc_status.error_code() << ", " << grpc_status.error_message();
  }
  auto &error_msgs = proto_reply.error_msg();
  auto &reply_instances = proto_reply.instances();
  if (error_msgs.size() == 1 && error_msgs[0].error_code() != 0) {
    if (error_msgs[0].error_code() == SERVABLE_UNAVAILABLE) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "There are no available inference processes that occupy devices";
    }
    return INFER_STATUS_LOG_ERROR(FAILED) << "Remote call model failed: " << error_msgs[0].error_msg();
  }
  if (!reply_instances.empty() && static_cast<size_t>(reply_instances.size()) != request.size()) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Remote call model failed, reply instances size "
                                          << reply_instances.size() << " is not equal to request instances size "
                                          << request.size();
  }
  for (int i = 0; i < reply_instances.size(); i++) {
    ResultInstance result_instance;
    if (i < error_msgs.size() && error_msgs[i].error_code() != 0) {
      result_instance.error_msg = INFER_STATUS_LOG_ERROR(FAILED)
                                  << "Result instance " << i << " failed: " << error_msgs[i].error_msg();
    } else {
      auto &proto_instance = reply_instances[i];
      auto &proto_items = proto_instance.items();
      for (auto &output : proto_items) {
        if (!output.second.has_shm_data()) {
          return INFER_STATUS_LOG_ERROR(FAILED)
                 << "Result instance " << i << " invalid, there is no shared memory data";
        }
      }
      result_instance.data = result_instances[i].data;
    }
    reply->push_back(result_instance);
  }
  return SUCCESS;
}

Status GrpcNotifyMaster::CallModel(const RemoteCallModelContext &model_context,
                                   const std::vector<InstanceData> &request, std::vector<ResultInstance> *reply) {
  std::vector<SharedMemoryItem> alloc_shm_request;
  auto status = CallModelInner(model_context, request, reply, &alloc_shm_request);
  auto &memory_instance = SharedMemoryAllocator::Instance();
  for (auto &item : alloc_shm_request) {
    memory_instance.ReleaseMemoryItem(item);
  }
  return status;
}
}  // namespace serving
}  // namespace mindspore

================================================
FILE: mindspore_serving/ccsrc/worker/notfiy_master/grpc_notify.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_WORKER_GRPC_NOTIFY_H #define MINDSPORE_SERVING_WORKER_GRPC_NOTIFY_H #include #include #include #include "worker/notfiy_master/base_notify.h" #include "common/instance_data.h" #include "common/shared_memory.h" #include "proto/ms_master.pb.h" #include "proto/ms_master.grpc.pb.h" #include "worker/extra_worker/remote_call_model.h" namespace mindspore { namespace serving { class MS_API GrpcNotifyMaster : public BaseNotifyMaster { public: GrpcNotifyMaster(const std::string &master_address, const std::string &worker_address); ~GrpcNotifyMaster() override; Status Register(const WorkerRegSpec &worker_spec) override; Status Unregister() override; static Status NotifyFailed(const std::string &master_address, const std::string &error_msg); Status CallModel(const RemoteCallModelContext &model_context, const std::vector &request, std::vector *reply); static Status GetModelInfos(const std::string &master_address, const std::string &servable_name, uint32_t version_number, proto::GetModelInfoReply *reply); private: std::string master_address_; std::string worker_address_; std::atomic is_running_ = false; std::unique_ptr stub_; Status CallModelInner(const RemoteCallModelContext &model_context, const std::vector &request, std::vector *reply, std::vector *alloc_shm_request); Status CreateRequestShmInstance(const RemoteCallModelContext &model_context, const InstanceData &instance, proto::Instance *proto_instance, std::vector *alloc_shm_request); Status CreateResultShmInstance(const RemoteCallModelContext &model_context, ResultInstance *result_instance, proto::Instance *proto_instance); }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_WORKER_GRPC_NOTIFY_H ================================================ FILE: mindspore_serving/ccsrc/worker/predict_thread.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "worker/predict_thread.h" #include #include #include #include "worker/task_queue.h" #include "worker/stage_function.h" #include "common/buffer_tensor.h" #include "distributed_worker/distributed_model_loader.h" namespace mindspore::serving { serving::PredictThread::PredictThread() {} PredictThread::~PredictThread() noexcept { Stop(); } void PredictThread::PushPredictTask(const MethodStage &stage, const std::vector &inputs) { // create input for predict, and check std::vector valid_instances; for (auto &instance : inputs) { auto status = CheckPredictInput(stage.subgraph, instance); if (status != SUCCESS) { task_que_.PushTaskResult({instance}, status); continue; } valid_instances.push_back(instance); } if (!valid_instances.empty()) { auto group_name = AsGroupName(stage.stage_key, stage.subgraph); task_que_.PushTask(group_name, 0, valid_instances); } } void PredictThread::ThreadFunc(PredictThread *queue) { queue->Predict(); } void PredictThread::Predict() { while (true) { TaskItem task_item; task_que_.PopTask(&task_item); if (task_item.has_stopped) { MSI_LOG_INFO << "Predict task has stopped, exit predict thread"; break; } MSI_TIME_STAMP_START(InvokePredict) PredictHandle(task_item.task_info, task_item.instance_list); MSI_TIME_STAMP_END_EXTRA(InvokePredict, task_item.task_info.tag) } } void PredictThread::Stop() { task_que_.Stop(); for (auto &predict_thread : predict_threads_) { if (predict_thread.joinable()) { try { predict_thread.join(); } catch (const std::system_error &) { } catch (...) { } } } } std::string PredictThread::AsGroupName(const std::string &model_key, uint64_t subgraph) const { return model_key + "_subgraph" + std::to_string(subgraph); } void PredictThread::Start(const std::string &que_name, const std::shared_ptr &model_loader, const ModelMeta &model_meta, const TaskCallBack &task_callback) { MSI_EXCEPTION_IF_NULL(model_loader); MSI_EXCEPTION_IF_NULL(task_callback); model_loader_ = model_loader; model_meta_ = model_meta; auto &model_key = model_meta.common_meta.model_key; auto graph_num = model_loader_->GetGraphNum(); auto batch_size = model_loader->GetBatchSize(); // init executor info executor_info_.sub_graph_infos.resize(graph_num); executor_info_.batch_size = batch_size; for (uint64_t i = 0; i < graph_num; i++) { auto input_infos = model_loader_->GetInputInfos(i); auto &subgraph_info = executor_info_.sub_graph_infos[i]; subgraph_info.input_infos = input_infos; } // init task infos std::vector task_infos; for (uint64_t i = 0; i < graph_num; i++) { TaskInfo info; info.group_name = AsGroupName(model_key, i); info.subgraph = i; info.task_name = info.group_name; info.priority = 0; info.batch_size = batch_size; info.tag = "Model " + model_key + (graph_num > 1 ? " subgraph " + std::to_string(i) : ""); task_infos.push_back(info); } task_que_.Start(que_name, task_infos, task_callback); // start before predict_thread_ start bool support_pipeline_infer = model_meta.distributed_meta.enable_pipeline_infer && (std::dynamic_pointer_cast(model_loader) != nullptr); size_t thread_num = support_pipeline_infer ? 
model_meta.distributed_meta.stage_size : 1; for (size_t i = 0; i < thread_num; i++) { predict_threads_.emplace_back(ThreadFunc, this); } } void PredictThread::PredictHandle(const TaskInfo &task_info, const std::vector &instances) { Status status; try { std::vector instance_result; status = PredictInner(task_info, instances, &instance_result); if (status != SUCCESS) { task_que_.PushTaskResult(instances, status); return; } task_que_.PushTaskResult(instances, instance_result); return; } catch (const std::bad_alloc &ex) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: malloc memory failed"; } catch (const std::runtime_error &ex) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: runtime error occurred: " << ex.what(); } catch (const std::exception &ex) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: exception occurred: " << ex.what(); } catch (...) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: exception occurred"; } task_que_.PushTaskResult(instances, status); } Status PredictThread::PredictInner(const TaskInfo &task_info, const std::vector &instances, std::vector *instance_result) { Status status; std::vector inputs; for (auto &item : instances) { // cppcheck-suppress useStlAlgorithm inputs.push_back(item->data); } status = model_loader_->Predict(inputs, instance_result, task_info.subgraph); if (status != SUCCESS) { MSI_LOG_ERROR << "Predict failed, model info " << model_meta_.common_meta.model_key; return status; } return SUCCESS; } Status PredictThread::CheckPredictInput(uint64_t subgraph, const InstancePtr &instance) { const auto &inputs_info = executor_info_.sub_graph_infos[subgraph].input_infos; if (instance->data.size() < inputs_info.size()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Given model inputs size " << instance->data.size() << " less than model inputs size " << inputs_info.size(); } for (size_t i = 0; i < instance->data.size(); i++) { auto input_data = instance->data[i]; if (inputs_info[i].is_no_batch_dim) { if (static_cast(inputs_info[i].size) != input_data->data_size()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Given model input " << i << " size " << input_data->data_size() << " not match the size " << inputs_info[i].size << " defined in model"; } } else if (static_cast(inputs_info[i].size / executor_info_.batch_size) != input_data->data_size()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Given model input " << i << " size " << input_data->data_size() << " not match the size " << inputs_info[i].size / executor_info_.batch_size << " defined in model"; } if (inputs_info[i].data_type != input_data->data_type()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Given model input " << i << " data type " << input_data->data_type() << " not match the data type " << inputs_info[i].data_type << " defined in model"; } } return SUCCESS; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/worker/predict_thread.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_WORKER_PREDICT_THREAD_H #define MINDSPORE_SERVING_WORKER_PREDICT_THREAD_H #include #include #include #include #include #include #include #include #include "common/instance.h" #include "worker/inference/inference.h" #include "worker/task_queue.h" #include "worker/model_loader_base.h" namespace mindspore::serving { struct PredictSubgraphInfo { std::vector input_infos; }; struct PredictModelInfo { std::vector sub_graph_infos; uint64_t batch_size = 0; }; class PredictThread { public: PredictThread(); ~PredictThread() noexcept; void PushPredictTask(const MethodStage &stage, const std::vector &inputs); void Start(const std::string &que_name, const std::shared_ptr &model_loader, const ModelMeta &model_meta, const TaskCallBack &task_callback); void Stop(); uint64_t GetBatchSize() const { return executor_info_.batch_size; } private: TaskQueue task_que_; std::vector predict_threads_; ModelMeta model_meta_; std::shared_ptr model_loader_ = nullptr; PredictModelInfo executor_info_; static void ThreadFunc(PredictThread *queue); void Predict(); void PredictHandle(const TaskInfo &task_info, const std::vector &instances); Status PredictInner(const TaskInfo &task_info, const std::vector &instances, std::vector *instance_result); Status CheckPredictInput(uint64_t subgraph, const InstancePtr &instance); std::string AsGroupName(const std::string &model_key, uint64_t subgraph) const; }; } // namespace mindspore::serving #endif // MINDSPORE_SERVING_WORKER_PREDICT_THREAD_H ================================================ FILE: mindspore_serving/ccsrc/worker/register/argmax.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "worker/stage_function.h" #include "mindspore_serving/ccsrc/common/tensor.h" namespace mindspore::serving { class ArgmaxStageFunc : public CppStageFunctionBase { public: template void ArgmaxImp(const void *input, size_t *output, size_t data_size, size_t elemsize) { auto count = data_size / elemsize; auto data = reinterpret_cast(input); *output = 0; for (size_t i = 1; i < count; i++) { if (data[i] > data[*output]) { *output = i; } } } Status Call(const std::string &, const InstanceData &input, InstanceData *output) override { MSI_EXCEPTION_IF_NULL(output); auto input_x = input[0]; auto x_data = input_x->data(); auto out_tensor = std::make_shared(); out_tensor->set_data_type(serving::kMSI_Int64); (void)out_tensor->resize_data(sizeof(size_t)); out_tensor->set_shape({}); output->push_back(out_tensor); auto y_data = reinterpret_cast(out_tensor->mutable_data()); switch (input_x->data_type()) { case kMSI_Float32: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Float64: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Int8: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Uint8: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Int16: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Uint16: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Int32: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Uint32: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Int64: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; case kMSI_Uint64: ArgmaxImp(x_data, y_data, input_x->data_size(), input_x->itemsize()); break; default: return INFER_STATUS(FAILED) << "Argmax not support data type " << input_x->data_type(); } return SUCCESS; } size_t GetInputsCount(const std::string &) const override { return 1; } size_t GetOutputsCount(const std::string &) const override { return 1; } }; REGISTER_STAGE_FUNCTION(ArgmaxStageFunc, "argmax_cpp") } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/worker/servable_register.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "worker/servable_register.h" #include #include #include "worker/stage_function.h" namespace mindspore { namespace serving { ServableRegister &ServableRegister::Instance() { static ServableRegister storage = ServableRegister(); return storage; } Status ServableRegister::RegisterMethod(const MethodSignature &method) { MSI_LOG_INFO << "Declare method " << method.method_name << ", servable " << method.servable_name; servable_signatures_.servable_name = method.servable_name; for (auto &item : servable_signatures_.methods) { // cppcheck-suppress useStlAlgorithm if (item.method_name == method.method_name) { return INFER_STATUS_LOG_ERROR(FAILED) << "Method " << method.method_name << " has been registered more than once."; } } servable_signatures_.methods.push_back(method); return SUCCESS; } Status ServableRegister::DeclareModel(ModelMeta model) { auto &common_meta = model.common_meta; auto &local_meta = model.local_meta; MSI_LOG_INFO << "Declare model " << local_meta.model_files; if (servable_signatures_.servable_type == kServableTypeDistributed) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare model failed, servable has already been declared as distributed servable"; } servable_signatures_.servable_name = common_meta.servable_name; servable_signatures_.servable_type = kServableTypeLocal; if (local_meta.model_files.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare model failed, model files size cannot be 0"; } std::set cur_model_files; for (auto &model_item : servable_signatures_.model_metas) { for (auto &file_item : model_item.local_meta.model_files) { (void)cur_model_files.emplace(file_item); } } for (auto &file : local_meta.model_files) { if (file.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare model " << local_meta.model_files << " failed, model file cannot be empty"; } if (cur_model_files.count(file) > 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare model " << local_meta.model_files << " failed, model file '" << file << "' has already been used"; } } if (local_meta.model_format == ModelType::kUnknownType) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare model " << local_meta.model_files << " failed, model_format is not inited"; } for (auto &item : servable_signatures_.model_metas) { // cppcheck-suppress useStlAlgorithm if (item.common_meta.model_key == common_meta.model_key) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare model " << local_meta.model_files << " failed, the same model has already been declared"; } } servable_signatures_.model_metas.push_back(model); return SUCCESS; } Status ServableRegister::DeclareDistributedModel(ModelMeta model) { auto &common_meta = model.common_meta; MSI_LOG_INFO << "Declare distributed model " << common_meta.servable_name; if (servable_signatures_.servable_type == kServableTypeDistributed) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare distributed model failed, servable is repeatedly been declared as distributed servable"; } if (servable_signatures_.servable_type == kServableTypeLocal) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare distributed model failed, servable has already been declared as local servable"; } servable_signatures_.servable_name = common_meta.servable_name; servable_signatures_.servable_type = kServableTypeDistributed; if (model.distributed_meta.rank_size == 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare distributed model " << common_meta.servable_name << " failed, rank_size cannot be 0"; } if (model.distributed_meta.stage_size == 0) { return 
INFER_STATUS_LOG_ERROR(FAILED) << "Declare distributed model " << common_meta.servable_name << " failed, stage_size cannot be 0"; } servable_signatures_.model_metas.push_back(model); return SUCCESS; } Status ServableRegister::RegisterInputOutputInfo(const std::string &model_key, size_t inputs_count, size_t outputs_count, uint64_t subgraph) { MSI_LOG_INFO << "Declare model " << model_key << " subgraph " << subgraph << " inputs count " << inputs_count << " outputs count " << outputs_count; auto &model_metas = servable_signatures_.model_metas; auto it = std::find_if(model_metas.begin(), model_metas.end(), [model_key](const ModelMeta &item) { return item.common_meta.model_key == model_key; }); if (it == model_metas.end()) { return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, cannot find model " << model_key; } auto &common_meta = it->common_meta; if (common_meta.inputs_count.count(subgraph) > 0 && common_meta.inputs_count[subgraph] != inputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count " << common_meta.inputs_count[subgraph] << ", model: " << model_key; } if (common_meta.outputs_count.count(subgraph) > 0 && common_meta.outputs_count[subgraph] != outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count " << common_meta.outputs_count[subgraph] << ", model: " << model_key; } common_meta.inputs_count[subgraph] = inputs_count; common_meta.outputs_count[subgraph] = outputs_count; return SUCCESS; } Status ServableRegister::InitCallModelMethods(const std::map> &models) { for (auto &model_it : models) { auto model_key = model_it.first; auto &model_loader = model_it.second; auto graph_num = model_loader->GetGraphNum(); for (size_t subgraph = 0; subgraph < graph_num; subgraph++) { auto input_infos = model_loader->GetInputInfos(subgraph); auto output_infos = model_loader->GetOutputInfos(subgraph); auto status = RegisterOneCallModelMethod(model_key, input_infos.size(), output_infos.size(), subgraph); if (status != SUCCESS) { return status; } } } return SUCCESS; } std::string ServableRegister::GetCallModelMethodName(const std::string &model_key, uint64_t subgraph) { std::string method_name = "@call_" + model_key + "_" + std::to_string(subgraph); return method_name; } Status ServableRegister::RegisterOneCallModelMethod(const std::string &model_key, uint64_t input_count, uint64_t output_count, uint64_t subgraph) { std::string method_name = GetCallModelMethodName(model_key, subgraph); MethodSignature method; method.method_name = method_name; method.servable_name = servable_signatures_.servable_name; std::vector> model_inputs; for (uint64_t i = 0; i < input_count; i++) { (void)method.inputs.emplace_back("x" + std::to_string(i)); (void)model_inputs.emplace_back(std::make_pair(0, i)); // all method inputs are function inputs } std::vector> returns; for (uint64_t i = 0; i < output_count; i++) { (void)method.outputs.emplace_back("y" + std::to_string(i)); (void)returns.emplace_back(std::make_pair(1, i)); } method.AddStageModel(model_key, model_inputs, subgraph); method.SetReturn(returns); auto status = RegisterMethod(method); if (status != SUCCESS) { MSI_LOG_ERROR << "Register Method failed"; return status; } status = RegisterInputOutputInfo(model_key, input_count, output_count, subgraph); if (status != SUCCESS) { MSI_LOG_ERROR << "Register model input and output info failed"; return status; } return SUCCESS; } 
Status ServableRegister::CheckModels(const std::map> &models) { auto const &signature = servable_signatures_; if (signature.methods.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "There is no method registered for servable"; } if (models.size() != signature.model_metas.size()) { return INFER_STATUS_LOG_ERROR(FAILED) << "The number " << signature.model_metas.size() << " of models declared is not equal to the number " << models.size() << " of models loaded"; } for (auto &model_meta : signature.model_metas) { auto &model_key = model_meta.common_meta.model_key; auto model_load_it = models.find(model_key); if (model_load_it == models.end()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_key << " has not been loaded"; } auto &model_loader = model_load_it->second; auto batch_size = model_loader->GetBatchSize(); if (batch_size == 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid batch size 0, model info: " << model_key; } auto graph_num = model_loader->GetGraphNum(); if (graph_num == 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid subgraph number 0, model info: " << model_key; } for (uint64_t subgraph = 0; subgraph < graph_num; subgraph++) { auto input_infos = model_loader->GetInputInfos(subgraph); auto output_infos = model_loader->GetOutputInfos(subgraph); MSI_LOG_INFO << "Print model info, model info: '" << model_meta.common_meta.model_key << "', subgraph " << subgraph; MSI_LOG_INFO << "Model input infos: count " << input_infos.size(); for (auto &item : input_infos) { MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; } MSI_LOG_INFO << "Model output infos: count " << output_infos.size(); for (auto &item : output_infos) { MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; } const auto &common_meta = model_meta.common_meta; if (common_meta.inputs_count.count(subgraph) > 0 && input_infos.size() != common_meta.inputs_count.at(subgraph)) { return INFER_STATUS_LOG_ERROR(FAILED) << "The inputs count " << common_meta.inputs_count.at(subgraph) << " in register_method " << "not equal to the count " << input_infos.size() << " defined in model, model info: " << model_key << ", subgraph: " << subgraph; } if (common_meta.outputs_count.count(subgraph) > 0 && output_infos.size() != common_meta.outputs_count.at(subgraph)) { return INFER_STATUS_LOG_ERROR(FAILED) << "The outputs count " << common_meta.outputs_count.at(subgraph) << " in register_method " << "not equal to the count " << output_infos.size() << " defined in model, model info: " << model_key << ", subgraph: " << subgraph; } } } return SUCCESS; } Status ServableRegister::CheckOneMethod(const MethodSignature &method) { const auto &servable_name = servable_signatures_.servable_name; const auto &model_metas = servable_signatures_.model_metas; for (auto &stage_it : method.stage_map) { auto stage_index = stage_it.first; auto &stage = stage_it.second; for (size_t input_index = 0; input_index < stage.stage_inputs.size(); input_index++) { auto input_stage_index = stage.stage_inputs[input_index].first; auto output_index = stage.stage_inputs[input_index].second; // method input if (input_stage_index == 0) { if (output_index >= method.inputs.size()) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "The stage " << stage_index << " " << input_index << "th input uses method " << output_index << "th input, that is greater than the method inputs size " << method.inputs.size() << ", servable: " << servable_name << ", method: " << method.method_name; } continue; } // check input stage index if 
(input_stage_index >= stage_index) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "The " << input_index << "th input data of stage " << stage_index << " cannot not come from stage " << input_stage_index << ", servable: " << servable_name << ", method: " << method.method_name; } // check input stage output index auto it = method.stage_map.find(input_stage_index); if (it == method.stage_map.end()) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Cannot find stage " << input_stage_index << " from method define information, " << ", servable: " << servable_name << ", method: " << method.method_name; } const auto &input_stage = it->second; if (input_stage.stage_type == kMethodStageTypePyFunction) { size_t input_count, output_count; if (!PyStageFunctionStorage::Instance()->GetPyFunctionInfo(input_stage.stage_key, &input_count, &output_count)) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "PyFunction " << input_stage.stage_key << " is not defined, " << ", servable: " << servable_name << ", method: " << method.method_name; } if (output_index >= output_count) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "The stage(begin with 1) " << stage_index << " " << input_index << "th input uses python function " << input_stage.stage_key << " " << output_index << "th output, that is greater than the function output size " << output_count << ", servable: " << servable_name << ", method: " << method.method_name; } } else if (input_stage.stage_type == kMethodStageTypeCppFunction) { auto function = CppStageFunctionStorage::Instance().GetFunction(input_stage.stage_key); if (function == nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "CppFunction " << input_stage.stage_key << " is not defined, " << ", servable: " << servable_name << ", method: " << method.method_name; } auto func_output_count = function->GetOutputsCount(input_stage.stage_key); if (output_index >= func_output_count) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "The stage(begin with 1) " << stage_index << " " << input_index << "th input uses c++ function " << input_stage.stage_key << " " << output_index << "th output, that is greater than the function output size " << func_output_count << ", servable: " << servable_name << ", method: " << method.method_name; } } else if (input_stage.stage_type == kMethodStageTypeModel) { auto model_it = std::find_if(model_metas.begin(), model_metas.end(), [&input_stage](const ModelMeta &model_meta) { return input_stage.stage_key == model_meta.common_meta.model_key; }); if (model_it == model_metas.end()) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Model " << input_stage.stage_key << " is not defined, " << ", servable: " << servable_name << ", method: " << method.method_name; } if (model_it->common_meta.outputs_count.count(input_stage.subgraph) == 0) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Model " << input_stage.stage_key << " subgraph " << input_stage.subgraph << " is not declared" << ", servable: " << servable_name << ", method: " << method.method_name; } auto model_output_count = model_it->common_meta.outputs_count.at(input_stage.subgraph); if (output_index >= model_output_count) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "The stage(begin with 1) " << stage_index << " " << input_index << "th input uses model " << input_stage.stage_key << " subgraph " << input_stage.subgraph << " " << output_index << "th output, that is greater than the model output size " << model_output_count << ", servable: " << servable_name << ", method: " << method.method_name; } } else { return 
INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Invalid stage type " << static_cast(stage.stage_type) << ", servable: " << servable_name << ", method: " << method.method_name; } } } return SUCCESS; } Status ServableRegister::CheckMethods() { std::set method_set; Status status; for (const auto &method : servable_signatures_.methods) { if (method_set.count(method.method_name) > 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_signatures_.servable_name << " method '" << method.method_name << "' has been defined repeatedly"; } (void)method_set.emplace(method.method_name); status = CheckOneMethod(method); if (status != SUCCESS) { return status; } } return SUCCESS; } Status ServableRegister::InitMethodBatchSize(const std::map> &models) { // stages only use method inputs as inputs batch_size == mini model batch size // other stages batch_size == max model batch size for (auto &method : servable_signatures_.methods) { uint64_t mini_batch_size = UINT32_MAX; uint64_t max_batch_size = 0; for (auto &stage_it : method.stage_map) { auto &stage = stage_it.second; if (stage.stage_type == kMethodStageTypeModel) { auto model_it = models.find(stage.stage_key); if (model_it == models.end()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << stage.stage_key << " has not been loaded"; } stage.batch_size = model_it->second->GetBatchSize(); if (stage.batch_size < mini_batch_size) { mini_batch_size = stage.batch_size; } if (stage.batch_size > max_batch_size) { max_batch_size = stage.batch_size; } } } if (mini_batch_size == UINT32_MAX || max_batch_size == 0) { mini_batch_size = 1; max_batch_size = 1; } for (auto &stage_it : method.stage_map) { auto &stage = stage_it.second; if (stage.stage_type != kMethodStageTypeModel && stage.batch_size == 0) { auto all_method_input = std::all_of(stage.stage_inputs.begin(), stage.stage_inputs.end(), [](const std::pair &item) { return item.first == 0; }); if (all_method_input) { stage.batch_size = mini_batch_size; } else { stage.batch_size = max_batch_size; } } } } return SUCCESS; } Status ServableRegister::InitOnModelsLoad(const std::map> &models) { Status status; status = CheckModels(models); if (status != SUCCESS) { MSI_LOG_ERROR << "Check models failed"; return status; } status = InitCallModelMethods(models); if (status != SUCCESS) { MSI_LOG_ERROR << "Init call model methods failed"; return status; } status = CheckMethods(); if (status != SUCCESS) { MSI_LOG_ERROR << "Check methods failed"; return status; } status = InitMethodBatchSize(models); if (status != SUCCESS) { MSI_LOG_ERROR << "Init models batch size failed"; return status; } return SUCCESS; } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/worker/servable_register.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
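InitMethodBatchSize above resolves the batch size of non-model stages from the model stages around them: a function stage fed only by method inputs gets the smallest model batch size, every other function stage gets the largest. A minimal, self-contained sketch of that rule follows; the Stage/ResolveBatchSizes names are illustrative, not Serving types.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for a method stage: each input is addressed as
// (producer stage index, output index); producer index 0 means "method input".
struct Stage {
  bool is_model = false;
  uint64_t batch_size = 0;  // known for model stages, resolved for the others
  std::vector<std::pair<size_t, size_t>> inputs;
};

// Resolve function-stage batch sizes the way InitMethodBatchSize does.
void ResolveBatchSizes(std::vector<Stage> *stages) {
  uint64_t mini = UINT32_MAX;
  uint64_t maxi = 0;
  for (const auto &s : *stages) {
    if (s.is_model) {
      mini = std::min(mini, s.batch_size);
      maxi = std::max(maxi, s.batch_size);
    }
  }
  if (maxi == 0) {  // no model stage at all: fall back to batch size 1
    mini = 1;
    maxi = 1;
  }
  for (auto &s : *stages) {
    if (s.is_model || s.batch_size != 0) {
      continue;
    }
    bool only_method_inputs = std::all_of(
      s.inputs.begin(), s.inputs.end(), [](const std::pair<size_t, size_t> &in) { return in.first == 0; });
    s.batch_size = only_method_inputs ? mini : maxi;
  }
}

int main() {
  std::vector<Stage> stages{
    {false, 0, {{0, 0}}},          // stage 1, preprocess: reads only method inputs
    {true, 4, {{1, 0}}},           // stage 2, model A with batch size 4
    {true, 8, {{1, 0}}},           // stage 3, model B with batch size 8
    {false, 0, {{2, 0}, {3, 0}}},  // stage 4, postprocess: reads model outputs
  };
  ResolveBatchSizes(&stages);
  std::cout << stages[0].batch_size << " " << stages[3].batch_size << "\n";  // prints: 4 8
  return 0;
}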
================================================
FILE: mindspore_serving/ccsrc/worker/servable_register.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the License.
 */
#ifndef MINDSPORE_SERVING_SERVABLE_REGISTER_H
#define MINDSPORE_SERVING_SERVABLE_REGISTER_H

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "common/servable.h"
#include "worker/model_loader_base.h"

namespace mindspore {
namespace serving {
class MS_API ServableRegister {
 public:
  static ServableRegister &Instance();
  const ServableSignature &GetServableSignature() const { return servable_signatures_; }
  // register_method
  Status RegisterMethod(const MethodSignature &method);
  // call_model
  Status RegisterInputOutputInfo(const std::string &model_key, size_t inputs_count, size_t outputs_count,
                                 uint64_t subgraph = 0);
  // declare_model
  Status DeclareModel(ModelMeta model);
  Status DeclareDistributedModel(ModelMeta model);
  static std::string GetCallModelMethodName(const std::string &model_key, uint64_t subgraph);
  Status InitOnModelsLoad(const std::map<std::string, std::shared_ptr<ModelLoaderBase>> &models);

 private:
  ServableSignature servable_signatures_;
  Status RegisterOneCallModelMethod(const std::string &model_key, uint64_t input_count, uint64_t output_count,
                                    uint64_t subgraph);
  Status CheckModels(const std::map<std::string, std::shared_ptr<ModelLoaderBase>> &models);
  Status InitCallModelMethods(const std::map<std::string, std::shared_ptr<ModelLoaderBase>> &models);
  Status CheckMethods();
  Status InitMethodBatchSize(const std::map<std::string, std::shared_ptr<ModelLoaderBase>> &models);
  Status CheckOneMethod(const MethodSignature &method);
};
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_SERVABLE_REGISTER_H
*/ #include "worker/stage_function.h" #include namespace mindspore::serving { bool CppStageFunctionStorage::Register(const std::string &function_name, std::shared_ptr function) { if (function_map_.find(function_name) != function_map_.end()) { MSI_LOG_WARNING << "function " << function_name << " has been registered"; return false; } function_map_[function_name] = std::move(function); return true; } void CppStageFunctionStorage::Unregister(const std::string &function_name) { auto it = function_map_.find(function_name); if (it == function_map_.end()) { return; } (void)function_map_.erase(it); } CppStageFunctionStorage &CppStageFunctionStorage::Instance() { static CppStageFunctionStorage storage = CppStageFunctionStorage(); return storage; } std::shared_ptr CppStageFunctionStorage::GetFunction(const std::string &func_name) const { auto it = function_map_.find(func_name); if (it != function_map_.end()) { return it->second; } return nullptr; } CppRegStageFunction::CppRegStageFunction(const std::string &function_name, std::shared_ptr function) { func_name_ = function_name; MSI_LOG_INFO << "Register C++ function " << function_name; register_success_ = CppStageFunctionStorage::Instance().Register(function_name, std::move(function)); } CppRegStageFunction::~CppRegStageFunction() noexcept { if (register_success_) { MSI_LOG_INFO << "Unregister C++ function " << func_name_; CppStageFunctionStorage::Instance().Unregister(func_name_); } } PyStageFunctionStorage::PyStageFunctionStorage() = default; PyStageFunctionStorage::~PyStageFunctionStorage() = default; std::shared_ptr PyStageFunctionStorage::Instance() { static std::shared_ptr instance = nullptr; if (instance == nullptr) { instance = std::make_shared(); } return instance; } void PyStageFunctionStorage::Register(const std::string &func_name, size_t inputs_count, size_t outputs_count) { function_infos_[func_name] = std::make_pair(inputs_count, outputs_count); MSI_LOG_INFO << "Register python stage function " << func_name << " inputs count " << inputs_count << " outputs count " << outputs_count; } bool PyStageFunctionStorage::HasPyFunction(const std::string &func_name) { auto it = function_infos_.find(func_name); return it != function_infos_.end(); } bool PyStageFunctionStorage::GetPyFunctionInfo(const std::string &func_name, size_t *inputs_count, size_t *outputs_count) { MSI_EXCEPTION_IF_NULL(inputs_count); MSI_EXCEPTION_IF_NULL(outputs_count); auto it = function_infos_.find(func_name); if (it == function_infos_.end()) { return false; } *inputs_count = it->second.first; *outputs_count = it->second.second; return true; } std::vector PyStageFunctionStorage::GetPyCppFunctionInfo(const std::string &func_name) const { size_t inputs_count = 0; size_t outputs_count = 0; if (PyStageFunctionStorage::Instance()->GetPyFunctionInfo(func_name, &inputs_count, &outputs_count)) { return {inputs_count, outputs_count}; } auto function = CppStageFunctionStorage::Instance().GetFunction(func_name); if (!function) { return {}; } inputs_count = function->GetInputsCount(func_name); outputs_count = function->GetOutputsCount(func_name); if (inputs_count == 0 || outputs_count == 0) { MSI_LOG_ERROR << "Call " + func_name + " inputs or outputs count cannot be 0"; return {}; } return {inputs_count, outputs_count}; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/worker/stage_function.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache 
================================================
FILE: mindspore_serving/ccsrc/worker/stage_function.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_STAGE_FUNCTION_PY_H
#define MINDSPORE_SERVING_WORKER_STAGE_FUNCTION_PY_H

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "common/serving_common.h"
#include "common/instance.h"

namespace mindspore::serving {
class CppStageFunctionBase : public std::enable_shared_from_this<CppStageFunctionBase> {
 public:
  CppStageFunctionBase() = default;
  virtual ~CppStageFunctionBase() = default;
  virtual Status Call(const std::string &func_name, const InstanceData &input, InstanceData *output) = 0;
  virtual size_t GetInputsCount(const std::string &func_name) const = 0;
  virtual size_t GetOutputsCount(const std::string &func_name) const = 0;
};

class CppStageFunctionStorage {
 public:
  bool Register(const std::string &func_name, std::shared_ptr<CppStageFunctionBase> function);
  void Unregister(const std::string &func_name);
  std::shared_ptr<CppStageFunctionBase> GetFunction(const std::string &func_name) const;
  static CppStageFunctionStorage &Instance();

 private:
  std::unordered_map<std::string, std::shared_ptr<CppStageFunctionBase>> function_map_;
};

class CppRegStageFunction {
 public:
  CppRegStageFunction(const std::string &func_name, std::shared_ptr<CppStageFunctionBase> function);
  ~CppRegStageFunction() noexcept;

 private:
  std::string func_name_;
  bool register_success_ = false;
};

#define REGISTER_STAGE_FUNCTION(cls_name, func_name) \
  static CppRegStageFunction g_register_stage_function_##cls_name(func_name, std::make_shared<cls_name>());

class MS_API PyStageFunctionStorage {
 public:
  static std::shared_ptr<PyStageFunctionStorage> Instance();
  void Register(const std::string &func_name, size_t inputs_count, size_t outputs_count);
  bool HasPyFunction(const std::string &func_name);
  bool GetPyFunctionInfo(const std::string &func_name, size_t *inputs_count, size_t *outputs_count);
  std::vector<size_t> GetPyCppFunctionInfo(const std::string &func_name) const;
  PyStageFunctionStorage();
  ~PyStageFunctionStorage();

 private:
  std::unordered_map<std::string, std::pair<size_t, size_t>> function_infos_;
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_WORKER_STAGE_FUNCTION_PY_H
*/ #include "worker/task_queue.h" #include #include #include "worker/stage_function.h" namespace mindspore::serving { TaskQueue::TaskQueue() {} void TaskQueue::Start(const std::string &que_name, const std::vector &task_infos, const TaskCallBack &callback) { std::unique_lock lock{que_lock_}; if (is_running) { return; } que_name_ = que_name; task_callback_ = callback; methods_queue_.group_que_map.clear(); methods_queue_.groups_que_instances_count = 0; for (auto &info : task_infos) { if (info.batch_size == 0) { MSI_LOG_EXCEPTION << "Invalid batch size 0, queue name: " << que_name; } auto &method_queue = methods_queue_.group_que_map[info.group_name]; auto &stage_queue = method_queue.priority_que_map[info.priority]; stage_queue.task_info = info; } is_running = true; } void TaskQueue::Stop() { std::unique_lock lock{que_lock_}; if (!is_running) { return; } methods_queue_.group_que_map.clear(); task_callback_ = nullptr; is_running = false; cond_var_.notify_all(); } void TaskQueue::PushTask(const std::string &group_name, size_t priority, const std::vector &instances) { if (instances.empty()) { MSI_LOG_WARNING << "Instances cannot be empty"; return; } MSI_LOG_DEBUG << que_name_ << " Push instances count " << instances.size() << ", inputs size: " << instances[0]->data.size(); { std::unique_lock lock{que_lock_}; auto method_it = methods_queue_.group_que_map.find(group_name); if (method_it == methods_queue_.group_que_map.end()) { MSI_LOG_EXCEPTION << "Cannot find method " << group_name << " in task queue, queue name: " << que_name_; } auto &stage_queue = method_it->second; auto stage_it = stage_queue.priority_que_map.find(priority); if (stage_it == stage_queue.priority_que_map.end()) { MSI_LOG_EXCEPTION << "Cannot find stage index " << priority << " in task queue, method name: " << group_name << ", queue name: " << que_name_; } auto &que = stage_it->second; for (auto &instance : instances) { que.instance_list.push_back(instance); } stage_queue.priority_que_instances_count += instances.size(); methods_queue_.groups_que_instances_count += instances.size(); } cond_var_.notify_all(); } bool TaskQueue::FindProcessTaskQueue(std::string *method_name) { auto next_que = methods_queue_.next_exe_que; auto &que_map = methods_queue_.group_que_map; size_t index = 0; std::string name; for (auto &item : que_map) { if (item.second.priority_que_instances_count > 0 && (name.empty() || index >= next_que)) { name = item.first; if (index >= next_que) { break; } } index++; } if (name.empty()) { return false; } if (index + 1 >= que_map.size()) { methods_queue_.next_exe_que = 0; } else { methods_queue_.next_exe_que = index + 1; } *method_name = name; return true; } void TaskQueue::PopTask(TaskItem *task_item) { MSI_EXCEPTION_IF_NULL(task_item); std::unique_lock lock{que_lock_}; if (!is_running) { // before start, or after stop MSI_LOG_INFO << "Detect task queue is not running, maybe the Serving server is stopped."; task_item->has_stopped = true; return; } while (true) { if (methods_queue_.groups_que_instances_count == 0) { cond_var_.wait(lock, [this] { return !is_running || methods_queue_.groups_que_instances_count > 0; }); if (!is_running) { MSI_LOG_INFO << "Detect task queue '" << que_name_ << "' is not running, maybe the Serving server is stopped."; task_item->has_stopped = true; return; } } std::string method_name; if (!FindProcessTaskQueue(&method_name)) { MSI_LOG_EXCEPTION << "Cannot find task when the number " << methods_queue_.groups_que_instances_count << " of instances in task queue is not 0"; } auto &method_que = 
methods_queue_.group_que_map[method_name]; auto &stage_que_map = method_que.priority_que_map; auto stage_it = stage_que_map.rbegin(); for (; stage_it != stage_que_map.rend(); ++stage_it) { if (!stage_it->second.instance_list.empty()) { break; } } if (stage_it == stage_que_map.rend()) { MSI_LOG_EXCEPTION << "Cannot find task when the number " << method_que.priority_que_instances_count << " of instances in method task queue is not 0"; } auto &task_handle = stage_it->second; auto batch_size = task_handle.task_info.batch_size; // Pop a maximum of batch_size instances if (task_handle.instance_list.size() <= batch_size) { *task_item = task_handle; task_handle.instance_list.clear(); } else { *task_item = task_handle; auto &instances_ret = task_item->instance_list; (void)instances_ret.erase(instances_ret.begin() + static_cast(batch_size), instances_ret.end()); auto &instances_reserved = task_handle.instance_list; (void)instances_reserved.erase(instances_reserved.begin(), instances_reserved.begin() + static_cast(batch_size)); } MSI_LOG_DEBUG << que_name_ << " Pop instances count " << task_item->instance_list.size() << ", batch size: " << batch_size; method_que.priority_que_instances_count -= task_item->instance_list.size(); methods_queue_.groups_que_instances_count -= task_item->instance_list.size(); break; } } void TaskQueue::PushTaskResult(const InstancePtr &input, const ResultInstance &output) { if (!is_running) { MSI_LOG_INFO << "Task queue has exited"; return; } task_callback_({input}, {output}); } void TaskQueue::PushTaskResult(const std::vector &inputs, const std::vector &outputs) { if (!is_running) { MSI_LOG_INFO << "Task queue has exited"; return; } task_callback_(inputs, outputs); } void TaskQueue::PushTaskResult(const std::vector &inputs, const Status &failed_result) { std::vector result; for (auto &item : inputs) { (void)item; ResultInstance output; output.error_msg = failed_result; result.push_back(output); } PushTaskResult(inputs, result); } void PyTaskQueue::Start(const std::string &que_name, const std::vector &stage_infos, const TaskCallBack &callback) { std::vector task_infos; for (auto &item : stage_infos) { TaskInfo info; info.batch_size = item.batch_size; info.priority = item.stage_index; info.group_name = item.method_name; info.task_name = item.stage_key; info.tag = item.tag; task_infos.push_back(info); } task_queue_.Start(que_name, task_infos, callback); py_task_item_processing_ = TaskItem(); } void PyTaskQueue::Stop() { task_queue_.Stop(); } void PyTaskQueue::PushTask(const std::string &method_name, size_t stage_index, const std::vector &instances) { task_queue_.PushTask(method_name, stage_index, instances); } void PyTaskQueue::PyPopTask(TaskItem *task_item) { MSI_EXCEPTION_IF_NULL(task_item); task_queue_.PopTask(task_item); if (!task_item->has_stopped) { py_task_item_processing_ = *task_item; } } void PyTaskQueue::PyPushTaskResult(const std::vector &outputs) { if (!task_queue_.IsRunning()) { MSI_LOG_INFO << "Task queue has exited"; return; } auto &instance_list = py_task_item_processing_.instance_list; if (outputs.empty() || instance_list.size() < outputs.size()) { MSI_LOG_EXCEPTION << "processing task not match result, processing size " << instance_list.size() << ", result size " << outputs.size(); } std::vector instances; std::vector results; for (size_t i = 0; i < outputs.size(); i++) { instances.push_back(instance_list[i]); results.push_back(outputs[i]); } task_queue_.PushTaskResult(instances, results); (void)instance_list.erase(instance_list.begin(), 
instance_list.begin() + static_cast(outputs.size())); } CppTaskQueueThreadPool::CppTaskQueueThreadPool() = default; CppTaskQueueThreadPool::~CppTaskQueueThreadPool() = default; void CppTaskQueueThreadPool::ThreadFunc(CppTaskQueueThreadPool *thread_pool) { while (true) { TaskItem task_item; thread_pool->task_queue_.PopTask(&task_item); if (task_item.has_stopped) { return; } auto status = thread_pool->HandleTask(task_item); if (status != SUCCESS) { MSI_LOG_ERROR << "System error happens, thread exit"; return; } } } void CppTaskQueueThreadPool::Start(const std::string &que_name, const std::vector &stage_infos, const TaskCallBack &callback, uint32_t size) { if (is_running_) { return; } is_running_ = true; // start before ThreadFunc thread pool start std::vector task_infos; for (auto &item : stage_infos) { TaskInfo info; info.batch_size = item.batch_size; info.priority = item.stage_index; info.group_name = item.method_name; info.task_name = item.stage_key; info.tag = item.tag; task_infos.push_back(info); } task_queue_.Start(que_name, task_infos, callback); // start before ThreadFunc thread pool start for (uint32_t i = 0; i < size; ++i) { (void)pool_.emplace_back(ThreadFunc, this); } } void CppTaskQueueThreadPool::Stop() { task_queue_.Stop(); for (std::thread &thd : pool_) { if (thd.joinable()) { try { thd.join(); } catch (const std::system_error &) { } catch (...) { } } } pool_.clear(); is_running_ = false; } void CppTaskQueueThreadPool::PushTask(const std::string &method_name, size_t stage_index, const std::vector &instances) { task_queue_.PushTask(method_name, stage_index, instances); } Status CppTaskQueueThreadPool::HandleTask(const TaskItem &task_item) { Status status; auto &task_name = task_item.task_info.task_name; auto preprocess = CppStageFunctionStorage::Instance().GetFunction(task_name); if (!preprocess) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "System error, get preprocess " << task_name << " failed"; return status; } for (const auto &instance : task_item.instance_list) { ResultInstance result; try { status = preprocess->Call(task_name, instance->data, &result.data); } catch (const std::bad_alloc &ex) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: malloc memory failed"; } catch (const std::runtime_error &ex) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: runtime error occurred: " << ex.what(); } catch (const std::exception &ex) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: exception occurred: " << ex.what(); } catch (...) { status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: exception occurred"; } if (status != SUCCESS) { result.error_msg = status; } task_queue_.PushTaskResult(instance, result); } return SUCCESS; } } // namespace mindspore::serving ================================================ FILE: mindspore_serving/ccsrc/worker/task_queue.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
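PopTask above is the batching point: a consumer blocks on a condition variable and, once work is available, takes at most batch_size queued instances in one go. A minimal, self-contained sketch of that pattern follows; BatchQueue and the int payload are illustrative stand-ins.

#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <vector>

// A consumer blocks until items arrive, then pops up to `batch_size` at once.
class BatchQueue {
 public:
  void Push(const std::vector<int> &items) {
    {
      std::unique_lock<std::mutex> lock(mu_);
      que_.insert(que_.end(), items.begin(), items.end());
    }
    cv_.notify_all();
  }
  // Blocks until data arrives or Stop() is called; returns up to batch_size items.
  std::vector<int> Pop(size_t batch_size) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return stopped_ || !que_.empty(); });
    std::vector<int> out;
    while (!que_.empty() && out.size() < batch_size) {
      out.push_back(que_.front());
      que_.pop_front();
    }
    return out;  // empty if stopped with nothing queued
  }
  void Stop() {
    {
      std::unique_lock<std::mutex> lock(mu_);
      stopped_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<int> que_;
  bool stopped_ = false;
};

int main() {
  BatchQueue q;
  q.Push({1, 2, 3, 4, 5});
  std::cout << q.Pop(4).size() << "\n";  // prints 4: one full batch
  std::cout << q.Pop(4).size() << "\n";  // prints 1: the remainder
  return 0;
}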
================================================
FILE: mindspore_serving/ccsrc/worker/task_queue.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_TASK_QUEUE_H
#define MINDSPORE_SERVING_WORKER_TASK_QUEUE_H

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "common/instance.h"

namespace mindspore::serving {
struct TaskInfo {
  std::string group_name;  // method name
  std::string task_name;   // function name, model name
  uint64_t priority = 0;
  uint64_t batch_size = 0;
  uint64_t subgraph = 0;  // for model
  std::string tag;
};

struct TaskItem {
  bool has_stopped = false;  // whether the system is stopped
  TaskInfo task_info;
  std::vector<InstancePtr> instance_list;
};

using TaskCallBack =
  std::function<void(const std::vector<InstancePtr> &inputs, const std::vector<ResultInstance> &output)>;

struct TaskQueuePriority {
  std::map<uint64_t, TaskItem> priority_que_map;  // priority (stage index) -> task list
  uint64_t priority_que_instances_count = 0;
};

struct TaskQueueGroups {
  std::map<std::string, TaskQueuePriority> group_que_map;  // group name (method name) -> task queue
  size_t next_exe_que = 0;                                 // next method index
  uint64_t groups_que_instances_count = 0;
};

class TaskQueue {
 public:
  TaskQueue();
  void Start(const std::string &que_name, const std::vector<TaskInfo> &task_infos, const TaskCallBack &callback);
  void Stop();
  void PushTask(const std::string &group_name, size_t priority, const std::vector<InstancePtr> &instances);
  void PopTask(TaskItem *task_item);
  void PushTaskResult(const InstancePtr &input, const ResultInstance &output);
  void PushTaskResult(const std::vector<InstancePtr> &inputs, const std::vector<ResultInstance> &outputs);
  void PushTaskResult(const std::vector<InstancePtr> &inputs, const Status &failed_result);
  bool IsRunning() const { return is_running; }

 private:
  std::string que_name_;
  TaskQueueGroups methods_queue_;
  TaskCallBack task_callback_ = nullptr;
  std::mutex que_lock_;  // Lock only when the queue changes, to avoid deadlock caused by locking in complex scenarios.
  std::condition_variable cond_var_;
  bool is_running = false;

  bool FindProcessTaskQueue(std::string *method_name);
};

class MS_API PyTaskQueue {
 public:
  PyTaskQueue() = default;
  ~PyTaskQueue() = default;
  void Start(const std::string &que_name, const std::vector<MethodStage> &stage_infos, const TaskCallBack &callback);
  void Stop();
  void PushTask(const std::string &method_name, size_t stage_index, const std::vector<InstancePtr> &instances);
  // for python task
  void PyPopTask(TaskItem *task_item);
  void PyPushTaskResult(const std::vector<ResultInstance> &outputs);
  TaskInfo GetHandledTaskInfo() const { return py_task_item_processing_.task_info; }
  bool IsRunning() const { return task_queue_.IsRunning(); }

 private:
  TaskQueue task_queue_;
  TaskItem py_task_item_processing_;
};

class CppTaskQueueThreadPool {
 public:
  CppTaskQueueThreadPool();
  virtual ~CppTaskQueueThreadPool();
  void Start(const std::string &que_name, const std::vector<MethodStage> &stage_infos, const TaskCallBack &callback,
             uint32_t size = 4);
  void Stop();
  void PushTask(const std::string &method_name, size_t stage_index, const std::vector<InstancePtr> &instances);

 protected:
  TaskQueue task_queue_;
  std::atomic<bool> is_running_ = false;
  std::vector<std::thread> pool_;

  Status HandleTask(const TaskItem &task_item);
  static void ThreadFunc(CppTaskQueueThreadPool *thread_pool);
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_WORKER_TASK_QUEUE_H
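FindProcessTaskQueue above keeps a rotating cursor (next_exe_que) so that different method groups take turns being drained: it prefers the first non-empty group at or past the cursor, falls back to the first non-empty group overall, then advances the cursor. A self-contained sketch of that selection logic, with illustrative names:

#include <iostream>
#include <map>
#include <string>

// Rotating-cursor group selection, mirroring FindProcessTaskQueue.
struct Scheduler {
  std::map<std::string, size_t> pending;  // group name -> queued instance count
  size_t next = 0;                        // rotating cursor

  bool PickGroup(std::string *name) {
    size_t index = 0;
    std::string chosen;
    for (const auto &item : pending) {
      if (item.second > 0 && (chosen.empty() || index >= next)) {
        chosen = item.first;
        if (index >= next) {
          break;  // found a non-empty group at or past the cursor
        }
      }
      index++;
    }
    if (chosen.empty()) {
      return false;  // nothing queued anywhere
    }
    next = (index + 1 >= pending.size()) ? 0 : index + 1;  // advance the cursor
    *name = chosen;
    return true;
  }
};

int main() {
  Scheduler s;
  s.pending = {{"add", 3}, {"mul", 3}, {"sub", 0}};
  std::string g;
  for (int i = 0; i < 4 && s.PickGroup(&g); i++) {
    std::cout << g << " ";  // prints: add mul add add
    s.pending[g]--;         // consume one instance from the picked group
  }
  std::cout << "\n";
  return 0;
}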
================================================
FILE: mindspore_serving/ccsrc/worker/work_executor.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the License.
 */
#include "worker/work_executor.h"
#include <atomic>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "worker/stage_function.h"
#include "common/tensor.h"
#include "worker/servable_register.h"

namespace mindspore::serving {
WorkExecutor::WorkExecutor() = default;
WorkExecutor::~WorkExecutor() noexcept { Stop(); }

Status WorkExecutor::Init(const std::map<std::string, std::shared_ptr<ModelLoaderBase>> &model_loaders) {
  Status status;
  if (init_flag_) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Worker service has been initialized";
  }
  // servable can be nullptr
  model_loaders_ = model_loaders;
  status = ServableRegister::Instance().InitOnModelsLoad(model_loaders);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Init on models load failed";
    return status;
  }
  InitStageFunctionQueue();
  InitPredictTaskQueue();
  init_flag_ = true;
  return SUCCESS;
}

void WorkExecutor::StageCallback(const std::vector<InstancePtr> &instances,
                                 const std::vector<ResultInstance> &outputs) {
  if (instances.empty() || instances.size() != outputs.size()) {
    MSI_LOG_ERROR << "Invalid inputs size " << instances.size() << ", result size " << outputs.size();
    return;
  }
  // <method name, <stage index, instances>>
  std::map<std::string, std::map<uint64_t, std::vector<InstancePtr>>> outputs_real;
  for (size_t i = 0; i < instances.size(); i++) {
    auto &instance = instances[i];
    auto &output = outputs[i];
    if (output.error_msg != SUCCESS) {
      (void)ReplyError(instance, output.error_msg);
      continue;
    }
    CreateResultInstance(instance, output);
    outputs_real[instance->method_def->method_name][instance->stage_index].push_back(instance);
  }
  for (auto &method_instances_it : outputs_real) {
    for (auto &stage_instances_it : method_instances_it.second) {
      auto &stage_instances = stage_instances_it.second;
      if (!stage_instances.empty()) {
        auto &method_def = *stage_instances[0]->method_def;
        auto stage_index = stage_instances_it.first;
        OnReceiveStageInputs(method_def, stage_index + 1, stage_instances);
      }
    }
  }
}

void WorkExecutor::InitStageFunctionQueue() {
  // init cpp preprocess and postprocess
  auto stage_callback = [this](const std::vector<InstancePtr> &instances,
                               const std::vector<ResultInstance> &outputs) { StageCallback(instances, outputs); };
  auto const &signature = ServableRegister::Instance().GetServableSignature();
  // start the task queues that handle preprocess and postprocess
  std::vector<MethodStage> py_stage_infos;
  std::vector<MethodStage> cpp_stage_infos;
  for (auto &method : signature.methods) {
    for (auto &stage_it : method.stage_map) {
      auto &stage = stage_it.second;
      if (stage.stage_type == kMethodStageTypePyFunction) {
        MSI_LOG_INFO << "PyFunction stage " << stage.stage_key << ", method name: " << stage.method_name
                     << ", stage index: " << stage.stage_index << ", batch size: " << stage.batch_size;
        py_stage_infos.push_back(stage);
      } else if (stage.stage_type == kMethodStageTypeCppFunction) {
        MSI_LOG_INFO << "CppFunction stage " << stage.stage_key << ", method name: " << stage.method_name
                     << ", stage index: " << stage.stage_index << ", batch size: " << stage.batch_size;
        cpp_stage_infos.push_back(stage);
      }
    }
  }
  if (!py_stage_infos.empty()) {
    py_task_queue_.Start("PyTask", py_stage_infos, stage_callback);
  }
  if (!cpp_stage_infos.empty()) {
    cpp_task_queue_pool_.Start("CppTask", cpp_stage_infos, stage_callback, 3);  // 3 threads
  }
}

void WorkExecutor::InitPredictTaskQueue() {
  auto stage_callback = [this](const std::vector<InstancePtr> &instances,
                               const std::vector<ResultInstance> &outputs) { StageCallback(instances, outputs); };
  auto const &signature = ServableRegister::Instance().GetServableSignature();
  for (auto &model_meta : signature.model_metas) {
    auto model_key = model_meta.common_meta.model_key;
    auto &thread = predict_thread_map_[model_key];  // insert
    thread.Start("PredictTask", model_loaders_[model_key], model_meta, stage_callback);
  }
}

void WorkExecutor::Stop() {
  init_flag_ = false;
  for (auto &item : predict_thread_map_) {
    item.second.Stop();
  }
  predict_thread_map_.clear();
  ClearInstances(Status(WORKER_UNAVAILABLE, "Servable stopped"));
  for (auto &model : model_loaders_) {
    model.second->Clear();
  }
  model_loaders_.clear();
  py_task_queue_.Stop();
  cpp_task_queue_pool_.Stop();
}

Status WorkExecutor::Work(const RequestSpec &request_spec, const std::vector<InstanceData> &instances_data,
                          const WorkCallBack &on_process_done) {
  if (!init_flag_) {
    MSI_LOG_EXCEPTION << "Worker service has not been initialized";
  }
  auto user_id = WorkExecutor::GetNextUserId();
  InferSession infer_session;
  infer_session.call_back = on_process_done;
  auto const &signature = ServableRegister::Instance().GetServableSignature();
  auto method_def = signature.GetMethodDeclare(request_spec.method_name);
  if (method_def == nullptr) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Method " << request_spec.method_name << " is not supported";
  }
  std::vector<InstancePtr> instances;
  for (size_t i = 0; i < instances_data.size(); i++) {
    if (method_def->inputs.size() != instances_data[i].size()) {
      return INFER_STATUS_LOG_ERROR(FAILED)
             << "The inputs count " << instances_data[i].size() << " of instance " << i
             << " is not equal to the inputs count " << method_def->inputs.size() << " of the method "
             << request_spec.method_name;
    }
    auto instance = std::make_shared<Instance>();
    instances.push_back(instance);
    instance->method_def = method_def;
    instance->stage_data_list[0] = instances_data[i];  // stage 0 data: method input
    instance->stage_max = method_def->GetStageMax();
    instance->user_id = user_id;
  }
  infer_session.instances = instances;
  {
    std::unique_lock<std::mutex> lock(infer_session_map_mutex_);
    infer_session_map_[user_id] = infer_session;
  }
  OnReceiveStageInputs(*method_def, kStageStartIndex, instances);  // stage 1 is the first stage
  return SUCCESS;
}

void WorkExecutor::OnReceiveStageInputs(const MethodSignature &method_def, uint64_t stage_index,
                                        const std::vector<InstancePtr> &instances) {
  if (instances.empty()) {
    MSI_LOG_EXCEPTION << "Inputs cannot be empty";
  }
  auto stage_it = method_def.stage_map.find(stage_index);
  if (stage_it == method_def.stage_map.end()) {
    MSI_LOG_EXCEPTION << "Cannot find stage " << stage_index;
  }
  auto &stage = stage_it->second;
  CreateInputInstance(stage, instances);
  if (stage_index >= method_def.GetStageMax()) {
    (void)ReplyRequest(instances);
    return;
  }
  if (stage.stage_type == kMethodStageTypePyFunction) {
    py_task_queue_.PushTask(method_def.method_name, stage_index, instances);
  } else if (stage.stage_type == kMethodStageTypeCppFunction) {
    cpp_task_queue_pool_.PushTask(method_def.method_name, stage_index, instances);
  } else if (stage.stage_type == kMethodStageTypeModel) {
    auto it = predict_thread_map_.find(stage.stage_key);
    if (it == predict_thread_map_.end()) {
      MSI_LOG_EXCEPTION << "Cannot find model " << stage.stage_key << " in predict_thread_map_";
    }
    it->second.PushPredictTask(stage, instances);
  } else {
    MSI_LOG_EXCEPTION << "Invalid stage type " << static_cast<int>(stage.stage_type);
  }
}

bool WorkExecutor::ReplyRequest(const std::vector<InstancePtr> &outputs) {
  MSI_TIME_STAMP_START(ReplyRequest)
  for (auto &item : outputs) {
    (void)ReplyRequest(item);
  }
  MSI_TIME_STAMP_END(ReplyRequest)
  return true;
}

bool WorkExecutor::ReplyCallback(const InstancePtr &instance) {
  instance->stage_data_list.clear();
  instance->stage_index = instance->stage_max;
  std::unique_lock<std::mutex> lock(infer_session_map_mutex_);
  auto it = infer_session_map_.find(instance->user_id);
  if (it == infer_session_map_.end()) {
    MSI_LOG_WARNING << "Cannot find user in session map, user id " << instance->user_id;
    return false;
  }
  auto &infer_session = it->second;
  infer_session.reply_count++;
  if (infer_session.reply_count == infer_session.instances.size()) {
    infer_session.call_back(infer_session.instances);
    (void)infer_session_map_.erase(it);
  }
  return true;
}

bool WorkExecutor::ReplyRequest(const InstancePtr &instance) {
  instance->error_msg = SUCCESS;
  return ReplyCallback(instance);
}

bool WorkExecutor::ReplyError(const InstancePtr &instance, const Status &error_msg) {
  instance->error_msg = error_msg;
  instance->data.clear();
  return ReplyCallback(instance);
}

void WorkExecutor::CreateInputInstance(const MethodStage &stage, const std::vector<InstancePtr> &instances) {
  for (auto &instance : instances) {
    CreateInputInstance(stage, instance);
  }
}

void WorkExecutor::CreateInputInstance(const MethodStage &stage, const InstancePtr &instance) {
  instance->data.clear();
  const auto &inputs = stage.stage_inputs;
  instance->stage_index = stage.stage_index;
  for (auto &item : inputs) {
    if (item.first >= instance->stage_data_list.size()) {
      MSI_LOG_EXCEPTION << "Invalid input stage index " << item.first << ", data stage count "
                        << instance->stage_data_list.size();
    }
    auto &data = instance->stage_data_list[item.first];
    if (data.size() <= item.second) {
      MSI_LOG_EXCEPTION << "Invalid output index " << item.second << ", output count " << data.size()
                        << ", input stage index " << item.first << ", stage index " << stage.stage_index
                        << ", method " << stage.method_name;
    }
    instance->data.push_back(data[item.second]);
  }
}

void WorkExecutor::CreateResultInstance(const InstancePtr &instance, const ResultInstance &result) {
  instance->data.clear();
  auto stage_index = instance->stage_index;
  instance->stage_data_list[stage_index] = result.data;
}

uint64_t WorkExecutor::GetNextUserId() {
  static std::atomic<uint64_t> user_id = 0;
  return ++user_id;
}

uint64_t WorkExecutor::GetMaxBatchSize() const {
  uint64_t batch_size = 1;
  for (auto &model : predict_thread_map_) {
    auto model_batch = model.second.GetBatchSize();
    if (model_batch > batch_size) {
      batch_size = model_batch;
    }
  }
  return batch_size;
}

void WorkExecutor::ClearInstances(const Status &error_msg) {
  std::unique_lock<std::mutex> lock(infer_session_map_mutex_);
  MSI_LOG_INFO << "Clear instances, remaining request count " << infer_session_map_.size();
  for (auto &item : infer_session_map_) {
    auto &infer_session = item.second;
    for (auto &instance : infer_session.instances) {
      if (instance->stage_index != instance->stage_max) {
        instance->error_msg = error_msg;
      }
    }
    item.second.call_back(item.second.instances);
  }
  infer_session_map_.clear();
}
}  // namespace mindspore::serving
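CreateInputInstance above gathers a stage's inputs by the addressing scheme used throughout the executor: each instance keeps per-stage outputs keyed by stage index (stage 0 holds the original method inputs), and each stage input is the pair (producer stage index, output index). A self-contained sketch of that lookup, with illustrative names and string payloads:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using StageData = std::vector<std::string>;  // one stage's outputs

// Resolve a stage's input list from (producer stage index, output index) pairs.
StageData GatherStageInputs(const std::map<size_t, StageData> &stage_data_list,
                            const std::vector<std::pair<size_t, size_t>> &stage_inputs) {
  StageData inputs;
  for (const auto &ref : stage_inputs) {
    inputs.push_back(stage_data_list.at(ref.first).at(ref.second));
  }
  return inputs;
}

int main() {
  std::map<size_t, StageData> data;
  data[0] = {"x0", "x1"};  // stage 0: the original method inputs
  data[1] = {"pre_out"};   // stage 1 outputs (e.g. a preprocess function)
  // Stage 2 reads method input 1 and the preprocess output.
  auto in = GatherStageInputs(data, {{0, 1}, {1, 0}});
  std::cout << in[0] << " " << in[1] << "\n";  // prints: x1 pre_out
  return 0;
}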
================================================
FILE: mindspore_serving/ccsrc/worker/work_executor.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the License.
 */
#ifndef MINDSPORE_SERVING_WORKER_WORK_EXECUTOR_H
#define MINDSPORE_SERVING_WORKER_WORK_EXECUTOR_H

#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "common/serving_common.h"
#include "common/instance.h"
#include "common/servable.h"
#include "worker/model_loader_base.h"
#include "worker/predict_thread.h"
#include "worker/task_queue.h"

namespace mindspore::serving {
using WorkCallBack = std::function<void(const std::vector<InstancePtr> &instances)>;

struct InferSession {
  std::vector<InstancePtr> instances;
  size_t reply_count = 0;
  WorkCallBack call_back = nullptr;
};

class WorkExecutor : public std::enable_shared_from_this<WorkExecutor> {
 public:
  WorkExecutor();
  ~WorkExecutor() noexcept;
  Status Init(const std::map<std::string, std::shared_ptr<ModelLoaderBase>> &model_loaders);
  Status Work(const RequestSpec &request_spec, const std::vector<InstanceData> &inputs,
              const WorkCallBack &on_process_done);
  void Stop();
  static uint64_t GetNextUserId();
  void ClearInstances(const Status &error_msg);
  uint64_t GetMaxBatchSize() const;
  PyTaskQueue &GetPyTaskQueue() { return py_task_queue_; }

 private:
  std::map<std::string, std::shared_ptr<ModelLoaderBase>> model_loaders_;
  bool init_flag_ = false;
  std::map<std::string, PredictThread> predict_thread_map_;
  PyTaskQueue py_task_queue_;
  CppTaskQueueThreadPool cpp_task_queue_pool_;
  std::map<uint64_t, InferSession> infer_session_map_;
  std::mutex infer_session_map_mutex_;

  bool ReplyCallback(const InstancePtr &instance);
  bool ReplyError(const InstancePtr &context, const Status &error_msg);
  bool ReplyRequest(const std::vector<InstancePtr> &outputs);
  bool ReplyRequest(const InstancePtr &outputs);
  void OnReceiveStageInputs(const MethodSignature &method_def, uint64_t stage_index,
                            const std::vector<InstancePtr> &instances);
  static void CreateInputInstance(const MethodStage &stage, const InstancePtr &instance);
  static void CreateInputInstance(const MethodStage &stage, const std::vector<InstancePtr> &instances);
  static void CreateResultInstance(const InstancePtr &instance, const ResultInstance &result);
  void StageCallback(const std::vector<InstancePtr> &instances, const std::vector<ResultInstance> &outputs);
  void InitStageFunctionQueue();
  void InitPredictTaskQueue();
};
}  // namespace mindspore::serving
#endif  // MINDSPORE_SERVING_WORKER_WORK_EXECUTOR_H
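Worker::RunAsync in the file that follows guards requests with try_lock_shared in a spin loop rather than a blocking shared lock, so a shutdown path that holds the exclusive lock (while gRPC shutdown waits for in-flight requests) can never deadlock against incoming requests. A self-contained sketch of that gate pattern; the names and payloads are illustrative:

#include <chrono>
#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <thread>

std::shared_mutex gate;  // requests take it shared, shutdown takes it exclusive
bool running = true;

bool HandleRequest(int id) {
  while (true) {
    if (gate.try_lock_shared()) {  // non-blocking: never waits on the shutdown path
      std::cout << "request " << id << " served\n";
      gate.unlock_shared();
      return true;
    }
    if (!running) {  // shutdown in progress: reject instead of waiting forever
      std::cout << "request " << id << " rejected: stopped\n";
      return false;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1));  // back off and retry
  }
}

int main() {
  HandleRequest(1);  // served: the gate is free
  {
    std::unique_lock<std::shared_mutex> stop_lock(gate);  // shutdown holds the exclusive lock
    running = false;
    std::thread t(HandleRequest, 2);  // rejected while stopped
    t.join();
  }
  return 0;
}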
*/ #include "worker/worker.h" #include #include #include #include "pybind11/pybind11.h" #include "common/proto_tensor.h" #include "common/exit_handle.h" #include "worker/context.h" #include "worker/grpc/worker_process.h" #include "worker/task_queue.h" #include "worker/grpc/worker_server.h" #include "worker/servable_register.h" namespace py = pybind11; namespace mindspore { namespace serving { Worker &Worker::GetInstance() { static Worker instance; return instance; } Status Worker::RegisterWorker(const std::string &master_address, const std::string &worker_address) { notify_master_ = std::make_shared(master_address, worker_address); WorkerRegSpec worker_spec; worker_spec.servable_spec = servable_spec_; worker_spec.worker_address = worker_address; worker_spec.worker_pid = getpid(); auto status = notify_master_->Register(worker_spec); return status; } Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &on_finish) { Status status; RequestSpec request_spec; GrpcTensorHelper::GetRequestSpec(request, &request_spec); auto servable_name = request_spec.servable_name; auto method_name = request_spec.method_name; const ServableSignature &servable_signature = ServableRegister::Instance().GetServableSignature(); if (servable_signature.servable_name != servable_name) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Servable " << servable_name << " is not declared"; } auto method_signature = servable_signature.GetMethodDeclare(method_name); if (method_signature == nullptr) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Method " << method_name << " is not registered for servable " << servable_name; } const MethodSignature &method = *method_signature; std::vector instances_data; status = GrpcTensorHelper::CreateInstanceFromRequest(method, request, &instances_data); if (status != SUCCESS) { MSI_LOG(ERROR) << "transfer request to instances failed"; return status; } *(reply->mutable_servable_spec()) = request.servable_spec(); WorkCallBack on_process_done = [&request, reply, on_finish, method](const std::vector &instances) { GrpcTensorHelper::CreateReplyFromInstances(request, method, instances, reply); on_finish(); }; return RunAsync(request_spec, instances_data, on_process_done); } Status Worker::RunAsync(const RequestSpec &request_spec, const std::vector &instances_data, const WorkCallBack &on_process_done) { while (true) { // avoid deadlock when Worker::Clear->gRPC shutdown, while gRPC shutdown waiting all request finished if (worker_shared_lock_.try_lock_shared()) { auto status = RunAsyncInner(request_spec, instances_data, on_process_done); worker_shared_lock_.unlock_shared(); return status; } else if (!servable_started_) { return INFER_STATUS_LOG_ERROR(WORKER_UNAVAILABLE) << "RunAsync worker for inference failed, worker has not been started or stopped"; } std::chrono::milliseconds duration(1); // 1ms std::this_thread::sleep_for(duration); } } Status Worker::RunAsyncInner(const RequestSpec &request_spec, const std::vector &instances_data, const WorkCallBack &on_process_done) { if (!servable_started_) { return INFER_STATUS_LOG_ERROR(WORKER_UNAVAILABLE) << "RunAsync worker for inference failed, worker has not been started or stopped"; } if (instances_data.empty()) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Input instances count is 0"; } if (!CheckServableRequest(request_spec)) { return INFER_STATUS_LOG_ERROR(FAILED) << "Cannot find servable match " << request_spec.Repr(); } MSI_LOG_INFO << "New request, method: " << 
request_spec.method_name << ", instances count: " << instances_data.size(); return worker_executor_.Work(request_spec, instances_data, on_process_done); } Status Worker::Run(const RequestSpec &request_spec, const std::vector &instances_data, std::vector *out) { if (!servable_started_) { return INFER_STATUS_LOG_ERROR(FAILED) << "Run worker for inference failed, worker has not been started"; } MSI_EXCEPTION_IF_NULL(out); auto promise = std::make_shared>(); auto future = promise->get_future(); WorkCallBack on_process_done = [promise, out](const std::vector &instances) { *out = instances; promise->set_value(); }; auto status = RunAsync(request_spec, instances_data, on_process_done); if (status != SUCCESS) { return status; } future.get(); return SUCCESS; } Status Worker::StartGrpcServer(const std::string &server_address) { if (worker_grpc_server_ != nullptr) { return INFER_STATUS_LOG_ERROR(FAILED) << "Worker gRPC server is already running"; } worker_grpc_server_ = std::make_shared(); SSLConfig ssl_config; return worker_grpc_server_->Start(server_address, ssl_config, gRpcMaxMBMsgSize, "Worker gRPC"); } Status Worker::StartDistributedGrpcServer(std::shared_ptr servable, const std::string &server_address) { if (distributed_grpc_server_ != nullptr) { return INFER_STATUS_LOG_ERROR(FAILED) << "Distributed gRPC server is already running"; } distributed_grpc_server_ = std::make_shared(servable, server_address); SSLConfig ssl_config; return distributed_grpc_server_->Start(server_address, ssl_config, gRpcMaxMBMsgSize, "Distributed gRPC"); } Status Worker::StartServable(const std::string &servable_directory, const std::string &servable_name, uint32_t version_number, const std::map> &models, const std::string &master_address, const std::string &worker_address, bool own_device) { auto status = StartServableInner(servable_name, version_number, models, own_device); if (status != SUCCESS) { return status; } status = StartGrpcServer(worker_address); if (status != SUCCESS) { return status; } status = RegisterWorker(master_address, worker_address); if (status != SUCCESS) { return status; } status = INFER_STATUS(SUCCESS) << "Serving: Start servable success, servable directory: '" << servable_directory << "', servable name: '" << servable_name << "', version number: " << version_number; MSI_LOG_INFO << status.StatusMessage(); std::cout << status.StatusMessage() << std::endl; return SUCCESS; } Status Worker::StartServableInner(const std::string &servable_name, uint32_t version_number, const std::map> &models, bool own_device) { if (servable_started_) { return INFER_STATUS_LOG_ERROR(FAILED) << "A servable has been started, only one servable can run in a process currently."; } clear_flag_.clear(); auto status = worker_executor_.Init(models); if (status != SUCCESS) { return status; } servable_spec_.servable_name = servable_name; servable_spec_.version_number = version_number; servable_spec_.batch_size = worker_executor_.GetMaxBatchSize(); servable_spec_.methods.clear(); servable_spec_.own_device = own_device; for (auto &model_it : models) { ModelInfo model_info; auto &model_key = model_it.first; auto &model = model_it.second; model_info.batch_size = model->GetBatchSize(); auto graph_num = model->GetGraphNum(); model_info.sub_graph_infos.resize(graph_num); for (uint64_t i = 0; i < graph_num; i++) { model_info.sub_graph_infos[i].input_infos = model->GetInputInfos(i); model_info.sub_graph_infos[i].output_infos = model->GetOutputInfos(i); } servable_spec_.models[model_key] = model_info; } const ServableSignature &signature 
= ServableRegister::Instance().GetServableSignature(); for (auto &method : signature.methods) { ServableMethodInfo worker_method_info; bool has_model = false; bool has_func = false; for (auto &stage : method.stage_map) { if (stage.second.stage_type == kMethodStageTypeModel) { has_model = true; } else if (stage.second.stage_type == kMethodStageTypePyFunction || stage.second.stage_type == kMethodStageTypeCppFunction) { has_func = true; } } if (has_model && !has_func) { worker_method_info.only_model_stage = true; } else { worker_method_info.only_model_stage = false; } // This worker does not occupy device and is only used to run python function stage to support python parallelism. // If one method does not contain function stage, requests of this method do not need to routed to this // worker. if (!servable_spec_.own_device && worker_method_info.only_model_stage) { continue; } worker_method_info.name = method.method_name; for (auto &name : method.inputs) { worker_method_info.input_names.push_back(name); } servable_spec_.methods.push_back(worker_method_info); } servable_started_ = true; return SUCCESS; } void Worker::StopServable(bool notify_master) { exit_notify_master_ = notify_master; ExitSignalHandle::Instance().Stop(); } void Worker::Clear() { std::unique_lock lock(worker_shared_lock_); MSI_LOG_INFO << "Start clear worker session"; servable_started_ = false; worker_executor_.Stop(); if (exit_notify_master_ && notify_master_) { notify_master_->Unregister(); } if (worker_grpc_server_) { worker_grpc_server_->Stop(); worker_grpc_server_ = nullptr; } if (distributed_grpc_server_) { distributed_grpc_server_->Stop(); distributed_grpc_server_ = nullptr; } MSI_LOG_INFO << "End clear worker session"; } bool Worker::IsRunning() { return servable_started_; } Worker::~Worker() { Clear(); if (listening_parent_thread_.joinable()) { listening_parent_thread_.join(); } } bool Worker::CheckServableRequest(const RequestSpec &request_spec) { if (servable_spec_.servable_name != request_spec.servable_name) { return false; } if (request_spec.version_number != 0 && servable_spec_.version_number != request_spec.version_number) { return false; } return true; } Worker::Worker() {} void Worker::ClearOnSystemFailed(const Status &error_msg) { std::shared_lock lock(worker_shared_lock_); MSI_LOG_INFO << "Clear instances on system failed: " << error_msg.StatusMessage(); worker_executor_.ClearInstances(error_msg); } static std::vector GetAllChildrenPids(int cur_pid) { if (cur_pid <= 0) { return {}; } std::string get_all_children_pids = "ps -o pid --no-headers --ppid " + std::to_string(cur_pid); FILE *fp = popen(get_all_children_pids.c_str(), "r"); if (fp == nullptr) { return {}; } constexpr int max_result_size = 1024; char buf[max_result_size] = {0}; std::string cmd_result; while (fgets(buf, max_result_size, fp) != nullptr && cmd_result.size() <= max_result_size) { cmd_result += std::string(buf) + " "; } pclose(fp); if (cmd_result.size() == max_result_size || cmd_result.empty()) { return {}; } std::regex pid_reg("[0-9]+"); auto match_beg = std::sregex_iterator(cmd_result.begin(), cmd_result.end(), pid_reg); auto match_end = std::sregex_iterator(); if (match_beg == match_end) { return {}; } std::vector direct_children; for (auto item = match_beg; item != match_end; ++item) { auto pid_str = item->str(); auto pid = static_cast(std::strtol(pid_str.c_str(), nullptr, 10)); if (pid <= 0) { continue; } std::ifstream stat_fp("/proc/" + std::to_string(pid) + "/stat"); if (!stat_fp.is_open()) { continue; } constexpr int 
cache_size_max = 128; char cache[cache_size_max + 1] = {0}; stat_fp.read(cache, cache_size_max); std::string cache_str = cache; auto pos = cache_str.find(") "); if (pos == std::string::npos) { continue; } cache_str = cache_str.substr(pos + strlen(") S ")); int child_ppid = static_cast(std::strtol(cache_str.c_str(), nullptr, 10)); if (child_ppid != cur_pid) { continue; } direct_children.push_back(pid); } std::vector all_pids = direct_children; for (auto &pid : direct_children) { auto pids = GetAllChildrenPids(pid); all_pids.insert(all_pids.end(), pids.begin(), pids.end()); } return all_pids; } void Worker::StartListeningParentExitThread() { auto thread_func = [this]() { MSI_LOG_INFO << "Start listening parent"; auto init_parent_pid = getppid(); constexpr int sleep_period_in_ms = 100; constexpr int try_kill_children_times = 100; // exit when receive SIGINT SIGTERM, or parent process exit while (true) { if (ExitSignalHandle::Instance().HasStopped()) { MSI_LOG_WARNING << "Worker has received exit message, worker begin to exit"; break; } auto cur_parent_pid = getppid(); if (init_parent_pid != cur_parent_pid) { MSI_LOG_WARNING << "Worker detect parent pid=" << init_parent_pid << " has exited, worker begin to exit"; ExitSignalHandle::Instance().Stop(); break; } std::this_thread::sleep_for(std::chrono::milliseconds(sleep_period_in_ms)); } Clear(); auto cur_pid = getpid(); for (int i = 0; i < try_kill_children_times; i++) { // 100*100ms=10s auto child_pids = GetAllChildrenPids(cur_pid); if (child_pids.empty() && !continue_listen_children_) { break; } for (auto pid : child_pids) { kill(pid, SIGTERM); } std::this_thread::sleep_for(std::chrono::milliseconds(sleep_period_in_ms)); } MSI_LOG_INFO << "Stop listening parent"; }; listening_parent_thread_ = std::thread(thread_func); } } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/ccsrc/worker/worker.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_WORKER_WORKER_H #define MINDSPORE_SERVING_WORKER_WORKER_H #include #include #include #include #include #include #include #include "worker/work_executor.h" #include "common/serving_common.h" #include "proto/ms_service.pb.h" #include "worker/notfiy_master/grpc_notify.h" #include "common/grpc_server.h" #include "worker/task_queue.h" #include "common/grpc_async_server.h" #include "worker/model_loader_base.h" #include "worker/grpc/worker_server.h" #include "worker/distributed_worker/distributed_process/distributed_server.h" namespace mindspore { namespace serving { class MS_API Worker { public: Worker(); ~Worker(); static Worker &GetInstance(); void Clear(); Status Run(const RequestSpec &request_spec, const std::vector &instances_data, std::vector *out); Status RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, const PredictOnFinish &on_finish); Status RunAsync(const RequestSpec &request_spec, const std::vector &instances_data, const WorkCallBack &on_process_done); Status StartServable(const std::string &servable_directory, const std::string &servable_name, uint32_t version_number, const std::map> &models, const std::string &master_address, const std::string &worker_address, bool own_device); Status StartGrpcServer(const std::string &server_address); Status StartDistributedGrpcServer(std::shared_ptr servable, const std::string &server_address); void StopServable(bool notify_master = true); bool IsRunning(); Status RegisterWorker(const std::string &master_address, const std::string &worker_address); WorkExecutor &GetWorkExecutor() { return worker_executor_; } void ClearOnSystemFailed(const Status &error_msg); std::shared_ptr GetGrpcNotifyMaster() { return notify_master_; } void SetContinueListenChildren(bool continue_listen_children) { continue_listen_children_ = continue_listen_children; } void StartListeningParentExitThread(); private: WorkExecutor worker_executor_; ServableRegSpec servable_spec_; std::atomic_bool exit_notify_master_ = true; std::atomic_bool servable_started_ = false; std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; std::shared_ptr notify_master_ = nullptr; std::shared_ptr worker_grpc_server_ = nullptr; std::shared_ptr distributed_grpc_server_ = nullptr; std::shared_mutex worker_shared_lock_; bool continue_listen_children_ = false; std::thread listening_parent_thread_; Status StartServableInner(const std::string &servable_name, uint32_t version_number, const std::map> &models, bool own_device); Status RunAsyncInner(const RequestSpec &request_spec, const std::vector &instances_data, const WorkCallBack &on_process_done); bool CheckServableRequest(const RequestSpec &request_spec); }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_WORKER_WORKER_H ================================================ FILE: mindspore_serving/client/__init__.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Serving Client API, which can be used to access the Serving Server through gRPC"""

from .python.client import Client
from .python.client import SSLConfig

__all__ = []

__all__.extend([
    "Client", "SSLConfig"
])

================================================
FILE: mindspore_serving/client/cpp/client.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "client/cpp/client.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <grpcpp/grpcpp.h>
#include "proto/ms_service.pb.h"
#include "proto/ms_service.grpc.pb.h"

namespace mindspore {
namespace serving {
namespace client {
Status &Status::operator<<(DataType val) {
  std::unordered_map<DataType, std::string> data_type_map = {
    {DT_UINT8, "uint8"},   {DT_UINT16, "uint16"},   {DT_UINT32, "uint32"},   {DT_UINT64, "uint64"},
    {DT_INT8, "int8"},     {DT_INT16, "int16"},     {DT_INT32, "int32"},     {DT_INT64, "int64"},
    {DT_BOOL, "bool"},     {DT_FLOAT16, "float16"}, {DT_FLOAT32, "float32"}, {DT_FLOAT64, "float64"},
    {DT_STRING, "string"}, {DT_BYTES, "bytes"},     {DT_UNKNOWN, "unknown"},
  };
  auto it = data_type_map.find(val);
  if (it == data_type_map.end()) {
    status_msg_ += "unknown";
  } else {
    status_msg_ += it->second;
  }
  return *this;
}

Status &operator<<(Status &status, proto::DataType val) {
  std::unordered_map<proto::DataType, std::string> data_type_map = {
    {proto::MS_UINT8, "uint8"},     {proto::MS_UINT16, "uint16"},   {proto::MS_UINT32, "uint32"},
    {proto::MS_UINT64, "uint64"},   {proto::MS_INT8, "int8"},       {proto::MS_INT16, "int16"},
    {proto::MS_INT32, "int32"},     {proto::MS_INT64, "int64"},     {proto::MS_BOOL, "bool"},
    {proto::MS_FLOAT16, "float16"}, {proto::MS_FLOAT32, "float32"}, {proto::MS_FLOAT64, "float64"},
    {proto::MS_STRING, "string"},   {proto::MS_BYTES, "bytes"},     {proto::MS_UNKNOWN, "unknown"},
  };
  auto it = data_type_map.find(val);
  if (it == data_type_map.end()) {
    status << "unknown";
  } else {
    status << it->second;
  }
  return status;
}

Status &operator<<(Status &status, grpc::StatusCode val) {
  std::unordered_map<grpc::StatusCode, std::string> data_type_map = {
    {grpc::OK, "OK"},
    {grpc::CANCELLED, "CANCELLED"},
    {grpc::UNKNOWN, "UNKNOWN"},
    {grpc::INVALID_ARGUMENT, "INVALID_ARGUMENT"},
    {grpc::DEADLINE_EXCEEDED, "DEADLINE_EXCEEDED"},
    {grpc::NOT_FOUND, "NOT_FOUND"},
    {grpc::ALREADY_EXISTS, "ALREADY_EXISTS"},
    {grpc::PERMISSION_DENIED, "PERMISSION_DENIED"},
    {grpc::UNAUTHENTICATED, "UNAUTHENTICATED"},
    {grpc::RESOURCE_EXHAUSTED, "RESOURCE_EXHAUSTED"},
    {grpc::FAILED_PRECONDITION, "FAILED_PRECONDITION"},
    {grpc::ABORTED, "ABORTED"},
    {grpc::OUT_OF_RANGE, "OUT_OF_RANGE"},
    {grpc::UNIMPLEMENTED, "UNIMPLEMENTED"},
    {grpc::INTERNAL, "INTERNAL"},
    {grpc::UNAVAILABLE, "UNAVAILABLE"},
    {grpc::DATA_LOSS, "DATA_LOSS"},
  };
  auto it = data_type_map.find(val);
  if (it == data_type_map.end()) {
    status << "unknown";
  } else {
    status << it->second;
  }
  return status;
}

Status MutableTensor::SetBytesData(const std::vector<uint8_t> &val) {
  if (mutable_proto_tensor_ == nullptr) {
    return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  }
  auto
proto_shape = mutable_proto_tensor_->mutable_shape(); proto_shape->add_dims(1); mutable_proto_tensor_->set_dtype(proto::MS_BYTES); if (val.empty()) { return Status(INVALID_INPUTS) << "Input index bytes val len is empty"; } mutable_proto_tensor_->add_bytes_val(val.data(), val.size()); return SUCCESS; } Status MutableTensor::SetStrData(const std::string &val) { if (mutable_proto_tensor_ == nullptr) { return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr"; } auto proto_shape = mutable_proto_tensor_->mutable_shape(); proto_shape->add_dims(val.size()); mutable_proto_tensor_->set_dtype(proto::MS_STRING); if (val.empty()) { return Status(INVALID_INPUTS) << "string index string val len is empty"; } mutable_proto_tensor_->add_bytes_val(val); return SUCCESS; } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(uint8_t), shape, DT_UINT8); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(uint16_t), shape, DT_UINT16); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(uint32_t), shape, DT_UINT32); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(uint64_t), shape, DT_UINT64); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(int8_t), shape, DT_INT8); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(int16_t), shape, DT_INT16); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(int32_t), shape, DT_INT32); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(int64_t), shape, DT_INT64); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { std::vector val_uint8; std::transform(val.begin(), val.end(), std::back_inserter(val_uint8), [](bool item) { return static_cast(item); }); return SetData(val_uint8.data(), val_uint8.size() * sizeof(bool), shape, DT_BOOL); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(float), shape, DT_FLOAT32); } Status MutableTensor::SetData(const std::vector &val, const std::vector &shape) { return SetData(val.data(), val.size() * sizeof(double), shape, DT_FLOAT64); } Status MutableTensor::SetData(const void *data, size_t data_len, const std::vector &shape, DataType data_type) { if (mutable_proto_tensor_ == nullptr) { return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr"; } if (data == nullptr || data_len == 0) { return Status(INVALID_INPUTS) << "data cannot be nullptr, or data len cannot be 0"; } mutable_proto_tensor_->set_data(data, data_len); auto proto_shape = mutable_proto_tensor_->mutable_shape(); std::unordered_map> data_type_map = { {DT_UINT8, {proto::MS_UINT8, sizeof(uint8_t)}}, {DT_UINT16, {proto::MS_UINT16, sizeof(uint16_t)}}, {DT_UINT32, {proto::MS_UINT32, sizeof(uint32_t)}}, {DT_UINT64, {proto::MS_UINT64, sizeof(uint64_t)}}, {DT_INT8, {proto::MS_INT8, sizeof(int8_t)}}, {DT_INT16, {proto::MS_INT16, sizeof(int16_t)}}, {DT_INT32, {proto::MS_INT32, sizeof(int32_t)}}, {DT_INT64, {proto::MS_INT64, sizeof(int64_t)}}, 
    {DT_BOOL, {proto::MS_BOOL, sizeof(bool)}}, {DT_FLOAT16, {proto::MS_FLOAT16, 2}},
    {DT_FLOAT32, {proto::MS_FLOAT32, 4}},      {DT_FLOAT64, {proto::MS_FLOAT64, 8}},
  };
  auto it = data_type_map.find(data_type);
  if (it == data_type_map.end()) {
    return Status(INVALID_INPUTS) << "Unsupported input data type " << data_type;
  }
  mutable_proto_tensor_->set_dtype(it->second.first);
  auto shape_str = [](const std::vector<int64_t> &val) noexcept {
    std::stringstream sstream;
    sstream << "[";
    for (size_t i = 0; i < val.size(); i++) {
      sstream << val[i];
      if (i + 1 < val.size()) {
        sstream << ", ";
      }
    }
    sstream << "]";
    return sstream.str();
  };
  int64_t element_cnt = 1;
  for (auto &item : shape) {
    proto_shape->add_dims(item);
    if (item <= 0 || item >= INT64_MAX || INT64_MAX / element_cnt < item) {
      return Status(INVALID_INPUTS) << "Input shape invalid " << shape_str(shape);
    }
    // accumulate the element count implied by the shape for the data length check below
    element_cnt *= item;
  }
  auto item_size = it->second.second;
  if (static_cast<int64_t>(data_len) / element_cnt < item_size ||
      element_cnt * item_size != static_cast<int64_t>(data_len)) {
    return Status(INVALID_INPUTS) << "Input shape " << shape_str(shape) << " does not match data len " << data_len;
  }
  return SUCCESS;
}

Status Tensor::GetBytesData(std::vector<uint8_t> *val) const {
  if (val == nullptr) {
    return Status(SYSTEM_ERROR) << "input val cannot be nullptr";
  }
  if (proto_tensor_ == nullptr) {
    return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  }
  if (proto_tensor_->dtype() != proto::MS_BYTES) {
    return Status(INVALID_INPUTS) << "Output data type does not match, its real data type is "
                                  << proto_tensor_->dtype();
  }
  auto &bytes_data = proto_tensor_->bytes_val();
  if (bytes_data.size() != 1) {
    return Status(INVALID_INPUTS) << "Bytes value type size can only be 1";
  }
  val->resize(bytes_data[0].size());
  // copy from the proto payload rather than from the destination buffer itself
  memcpy(val->data(), bytes_data[0].data(), bytes_data[0].size());
  return SUCCESS;
}

Status Tensor::GetStrData(std::string *val) const {
  if (val == nullptr) {
    return Status(SYSTEM_ERROR) << "input val cannot be nullptr";
  }
  if (proto_tensor_ == nullptr) {
    return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  }
  if (proto_tensor_->dtype() != proto::MS_STRING) {
    return Status(INVALID_INPUTS) << "Output data type does not match, its real data type is "
                                  << proto_tensor_->dtype();
  }
  auto &bytes_data = proto_tensor_->bytes_val();
  if (bytes_data.size() != 1) {
    return Status(INVALID_INPUTS) << "String value type size can only be 1";
  }
  val->resize(bytes_data[0].size());
  // copy from the proto payload rather than from the destination buffer itself
  memcpy(val->data(), bytes_data[0].data(), bytes_data[0].size());
  return SUCCESS;
}

template <class DT, proto::DataType proto_dtype>
Status GetInputImp(const proto::Tensor *proto_tensor, std::vector<DT>
*val) { if (val == nullptr) { return Status(SYSTEM_ERROR) << "input val cannot be nullptr"; } if (proto_tensor == nullptr) { return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr"; } if (proto_tensor->dtype() != proto_dtype) { return Status(INVALID_INPUTS) << "Output data type is not match, its' real data type is " << proto_tensor->dtype(); } auto data = proto_tensor->data().data(); auto data_len = proto_tensor->data().length(); val->resize(data_len / sizeof(DT)); memcpy(val->data(), data, data_len); return SUCCESS; } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { if (val == nullptr) { return Status(SYSTEM_ERROR) << "input val cannot be nullptr"; } std::vector val_uint8; Status status = GetInputImp(proto_tensor_, &val_uint8); if (!status.IsSuccess()) { return status; } std::transform(val_uint8.begin(), val_uint8.end(), std::back_inserter(*val), [](uint8_t item) { return item != 0; }); return SUCCESS; } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetData(std::vector *val) const { return GetInputImp(proto_tensor_, val); } Status Tensor::GetFp16Data(std::vector *val) const { return GetInputImp(proto_tensor_, val); } DataType Tensor::GetDataType() const { if (proto_tensor_ == nullptr) { std::cout << "proto tensor cannot be nullptr" << std::endl; return DT_UNKNOWN; } std::unordered_map data_type_map = { {proto::MS_UNKNOWN, DT_UNKNOWN}, {proto::MS_UINT8, DT_UINT8}, {proto::MS_UINT16, DT_UINT16}, {proto::MS_UINT32, DT_UINT32}, {proto::MS_UINT64, DT_UINT64}, {proto::MS_INT8, DT_INT8}, {proto::MS_INT16, DT_INT16}, {proto::MS_INT32, DT_INT32}, {proto::MS_INT64, DT_INT64}, {proto::MS_BOOL, DT_BOOL}, {proto::MS_FLOAT16, DT_FLOAT16}, {proto::MS_FLOAT32, DT_FLOAT32}, {proto::MS_FLOAT64, DT_FLOAT64}, {proto::MS_STRING, DT_STRING}, {proto::MS_BYTES, DT_BYTES}, }; auto it_dt = data_type_map.find(proto_tensor_->dtype()); if (it_dt == data_type_map.end()) { std::cout << "Unsupported data type " << proto_tensor_->dtype() << std::endl; return DT_UNKNOWN; } return it_dt->second; } std::vector Tensor::GetShape() const { if (proto_tensor_ == nullptr) { std::cout << "proto tensor cannot be nullptr" << std::endl; return std::vector(); } std::vector shape; auto &dims = proto_tensor_->shape().dims(); std::copy(dims.begin(), dims.end(), std::back_inserter(shape)); return shape; } Tensor Instance::Get(const std::string &item_name) const { if (proto_instance_ == nullptr) { std::cout << "proto instance cannot be nullptr" << std::endl; return Tensor(nullptr, nullptr); } auto &items = proto_instance_->items(); auto it = items.find(item_name); if (it == items.end()) { std::cout << "Cannot find item name " << item_name << std::endl; return Tensor(nullptr, nullptr); } return Tensor(message_owner_, &it->second); } bool 
Instance::HasErrorMsg(int64_t *error_code, std::string *error_msg) const { if (error_code == nullptr) { return false; } if (error_msg == nullptr) { return false; } if (error_msg_ == nullptr) { return false; } *error_code = error_msg_->error_code(); *error_msg = error_msg_->error_msg(); return true; } MutableTensor MutableInstance::Add(const std::string &item_name) { if (mutable_proto_instance_ == nullptr) { std::cout << "proto instance cannot be nullptr" << std::endl; return MutableTensor(nullptr, nullptr); } auto items = mutable_proto_instance_->mutable_items(); auto &proto_tensor = (*items)[item_name]; return MutableTensor(message_owner_, &proto_tensor); } InstancesRequest::InstancesRequest() { request_ = std::make_shared(); } MutableInstance InstancesRequest::AddInstance() { auto proto_instance = request_->add_instances(); return MutableInstance(request_, proto_instance); } InstancesReply::InstancesReply() { reply_ = std::make_shared(); } std::vector InstancesReply::GetResult() const { std::vector instances; auto &proto_instances = reply_->instances(); auto &proto_error_msgs = reply_->error_msg(); for (int i = 0; i < proto_instances.size(); i++) { auto &proto_instance = proto_instances[i]; const proto::ErrorMsg *error_msg = nullptr; if (proto_error_msgs.size() == 1) { error_msg = &proto_error_msgs[0]; } else if (proto_error_msgs.size() == proto_instances.size() && proto_error_msgs[i].error_code() != 0) { error_msg = &proto_error_msgs[i]; } instances.push_back(Instance(reply_, &proto_instance, error_msg)); } return instances; } class ClientImpl { public: ClientImpl(const std::string &server_ip, uint64_t server_port) { std::string target_str = server_ip + ":" + std::to_string(server_port); auto channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()); stub_ = proto::MSService::NewStub(channel); } Status Predict(const proto::PredictRequest &request, proto::PredictReply *reply) { if (reply == nullptr) { return Status(SYSTEM_ERROR, "ClientImpl::Predict input reply cannot be nullptr"); } grpc::ClientContext context; // The actual RPC. 
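// Note: Predict is a blocking unary call on the generated MSService stub: it sends one
// PredictRequest and waits for the single PredictReply, which is written in place.
// Transport and server-side failures surface through the returned grpc::Status and are
// mapped below to a client::Status(FAILED) carrying the gRPC error message.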
grpc::Status status = stub_->Predict(&context, request, reply); if (status.ok()) { return SUCCESS; } else { std::cout << status.error_code() << ": " << status.error_message() << std::endl; return Status(FAILED, status.error_message()); } } private: std::unique_ptr stub_; }; Client::Client(const std::string &server_ip, uint64_t server_port, const std::string &servable_name, const std::string &method_name, uint64_t version_number) : server_ip_(server_ip), server_port_(server_port), servable_name_(servable_name), method_name_(method_name), version_number_(version_number), impl_(std::make_shared(server_ip, server_port)) {} Status Client::SendRequest(const InstancesRequest &request, InstancesReply *reply) { if (reply == nullptr) { return Status(SYSTEM_ERROR) << "input reply cannot be nullptr"; } proto::PredictRequest *proto_request = request.request_.get(); proto::PredictReply *proto_reply = reply->reply_.get(); auto servable_spec = proto_request->mutable_servable_spec(); servable_spec->set_name(servable_name_); servable_spec->set_method_name(method_name_); servable_spec->set_version_number(version_number_); Status result = impl_->Predict(*proto_request, proto_reply); return result; } } // namespace client } // namespace serving } // namespace mindspore ================================================ FILE: mindspore_serving/client/cpp/client.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_SERVING_CLIENT_H #define MINDSPORE_SERVING_CLIENT_H #include #include #include #include namespace google { namespace protobuf { class Message; } } // namespace google namespace mindspore { namespace serving { #define MS_API __attribute__((visibility("default"))) namespace proto { class Tensor; class Instance; class PredictRequest; class PredictReply; class ErrorMsg; } // namespace proto namespace client { using ProtoMsgOwner = std::shared_ptr; enum DataType { DT_UNKNOWN, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_BOOL, DT_FLOAT16, DT_FLOAT32, DT_FLOAT64, DT_STRING, DT_BYTES, }; enum StatusCode { SUCCESS = 0, FAILED, INVALID_INPUTS, SYSTEM_ERROR, UNAVAILABLE }; class MS_API Status { public: Status() : status_code_(FAILED) {} Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit) : status_code_(status_code), status_msg_(status_msg) {} bool IsSuccess() const { return status_code_ == SUCCESS; } enum StatusCode StatusCode() const { return status_code_; } std::string StatusMessage() { return status_msg_; } bool operator==(const Status &other) const { return status_code_ == other.status_code_; } bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; } bool operator!=(const Status &other) const { return status_code_ != other.status_code_; } bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; } operator bool() const = delete; template Status &operator<<(T val); Status &operator<<(DataType val); template Status &operator<<(const std::vector &val); private: enum StatusCode status_code_; std::string status_msg_; }; class MS_API Tensor { public: Tensor(const ProtoMsgOwner &owner, const proto::Tensor *proto_tensor) : message_owner_(owner), proto_tensor_(proto_tensor) {} virtual ~Tensor() = default; // Bytes type: for images etc. Status GetBytesData(std::vector *val) const; Status GetStrData(std::string *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetData(std::vector *val) const; Status GetFp16Data(std::vector *val) const; DataType GetDataType() const; std::vector GetShape() const; bool IsValid() const { return proto_tensor_ != nullptr; } protected: ProtoMsgOwner message_owner_; private: const proto::Tensor *proto_tensor_; }; class MS_API MutableTensor : public Tensor { public: MutableTensor(const ProtoMsgOwner &owner, proto::Tensor *proto_tensor) : Tensor(owner, proto_tensor), mutable_proto_tensor_(proto_tensor) {} ~MutableTensor() = default; // Bytes type: for images etc. 
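  // The Set* family below fills the wrapped proto::Tensor in place: SetBytesData and
  // SetStrData store a single element in the bytes_val field, while the typed SetData
  // overloads copy raw host memory into the data field and record dtype and shape.
  // The generic SetData(const void *, ...) overload validates that the element count
  // implied by `shape` matches `data_bytes_len` for the given data type.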
Status SetBytesData(const std::vector &val); Status SetStrData(const std::string &val); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const std::vector &val, const std::vector &shape); Status SetData(const void *data, size_t data_bytes_len, const std::vector &shape, DataType data_type); private: proto::Tensor *mutable_proto_tensor_; }; class MS_API Instance { public: Instance(const ProtoMsgOwner &owner, const proto::Instance *proto_instance, const proto::ErrorMsg *error_msg) : message_owner_(owner), proto_instance_(proto_instance), error_msg_(error_msg) {} virtual ~Instance() = default; Tensor Get(const std::string &item_name) const; bool IsValid() const { return proto_instance_ != nullptr; } bool HasErrorMsg(int64_t *error_code, std::string *error_msg) const; protected: ProtoMsgOwner message_owner_; private: const proto::Instance *proto_instance_; const proto::ErrorMsg *error_msg_; }; class MS_API MutableInstance : public Instance { public: MutableInstance(const ProtoMsgOwner &owner, proto::Instance *proto_instance) : Instance(owner, proto_instance, nullptr), mutable_proto_instance_(proto_instance) {} ~MutableInstance() = default; MutableTensor Add(const std::string &item_name); private: proto::Instance *mutable_proto_instance_; }; class MS_API InstancesRequest { public: InstancesRequest(); ~InstancesRequest() = default; MutableInstance AddInstance(); private: std::shared_ptr request_ = nullptr; friend class Client; }; class MS_API InstancesReply { public: InstancesReply(); ~InstancesReply() = default; std::vector GetResult() const; private: std::shared_ptr reply_ = nullptr; friend class Client; }; class ClientImpl; class MS_API Client { public: Client(const std::string &server_ip, uint64_t server_port, const std::string &servable_name, const std::string &method_name, uint64_t version_number = 0); ~Client() = default; Status SendRequest(const InstancesRequest &request, InstancesReply *reply); private: std::string server_ip_; uint64_t server_port_; std::string servable_name_; std::string method_name_; uint64_t version_number_ = 0; std::shared_ptr impl_; }; template Status &Status::operator<<(T val) { std::stringstream stringstream; stringstream << val; status_msg_ += stringstream.str(); return *this; } template Status &Status::operator<<(const std::vector &val) { operator<<("["); for (size_t i = 0; i < val.size(); i++) { operator<<(val[i]); if (i != val.size() - 1) { operator<<(", "); } } operator<<("["); return *this; } } // namespace client } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_CLIENT_H ================================================ FILE: mindspore_serving/client/python/__init__.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ ================================================ FILE: mindspore_serving/client/python/client.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """MindSpore Serving Client""" import grpc import numpy as np import mindspore_serving.proto.ms_service_pb2 as ms_service_pb2 import mindspore_serving.proto.ms_service_pb2_grpc as ms_service_pb2_grpc def _create_tensor(data, tensor=None): """Create tensor from numpy data""" if tensor is None: tensor = ms_service_pb2.Tensor() tensor.shape.dims.extend(data.shape) dtype_map = { np.bool: ms_service_pb2.MS_BOOL, np.int8: ms_service_pb2.MS_INT8, np.uint8: ms_service_pb2.MS_UINT8, np.int16: ms_service_pb2.MS_INT16, np.uint16: ms_service_pb2.MS_UINT16, np.int32: ms_service_pb2.MS_INT32, np.uint32: ms_service_pb2.MS_UINT32, np.int64: ms_service_pb2.MS_INT64, np.uint64: ms_service_pb2.MS_UINT64, np.float16: ms_service_pb2.MS_FLOAT16, np.float32: ms_service_pb2.MS_FLOAT32, np.float64: ms_service_pb2.MS_FLOAT64, } for k, v in dtype_map.items(): if k == data.dtype: tensor.dtype = v break if tensor.dtype == ms_service_pb2.MS_UNKNOWN: raise RuntimeError("Unknown data type " + str(data.dtype)) tensor.data = data.tobytes() return tensor def _create_scalar_tensor(vals, tensor=None): """Create tensor from scalar data""" if not isinstance(vals, (tuple, list)): vals = (vals,) return _create_tensor(np.array(vals), tensor) def _create_bytes_tensor(bytes_vals, tensor=None): """Create tensor from bytes data""" if tensor is None: tensor = ms_service_pb2.Tensor() if not isinstance(bytes_vals, (tuple, list)): bytes_vals = (bytes_vals,) tensor.shape.dims.extend([len(bytes_vals)]) tensor.dtype = ms_service_pb2.MS_BYTES for item in bytes_vals: tensor.bytes_val.append(item) return tensor def _create_str_tensor(str_vals, tensor=None): """Create tensor from str data""" if tensor is None: tensor = ms_service_pb2.Tensor() if not isinstance(str_vals, (tuple, list)): str_vals = (str_vals,) tensor.shape.dims.extend([len(str_vals)]) tensor.dtype = ms_service_pb2.MS_STRING for item in str_vals: tensor.bytes_val.append(bytes(item, encoding="utf8")) return tensor def _create_numpy_from_tensor(tensor): """Create numpy from protobuf tensor""" dtype_map = { ms_service_pb2.MS_BOOL: np.bool, ms_service_pb2.MS_INT8: np.int8, ms_service_pb2.MS_UINT8: np.uint8, ms_service_pb2.MS_INT16: np.int16, ms_service_pb2.MS_UINT16: np.uint16, ms_service_pb2.MS_INT32: np.int32, 
ms_service_pb2.MS_UINT32: np.uint32, ms_service_pb2.MS_INT64: np.int64, ms_service_pb2.MS_UINT64: np.uint64, ms_service_pb2.MS_FLOAT16: np.float16, ms_service_pb2.MS_FLOAT32: np.float32, ms_service_pb2.MS_FLOAT64: np.float64, } if tensor.dtype == ms_service_pb2.MS_STRING or tensor.dtype == ms_service_pb2.MS_BYTES: result = [] for item in tensor.bytes_val: if tensor.dtype == ms_service_pb2.MS_STRING: result.append(bytes.decode(item)) else: result.append(item) if len(result) == 1: return result[0] return result result = np.frombuffer(tensor.data, dtype_map[tensor.dtype]).reshape(tensor.shape.dims) return result def _check_str(arg_name, str_val): """Check whether the input parameters are reasonable str input""" if not isinstance(str_val, str): raise RuntimeError(f"Parameter '{arg_name}' should be str, but actually {type(str_val)}") if not str_val: raise RuntimeError(f"Parameter '{arg_name}' should not be empty str") def _check_int(arg_name, int_val, minimum=None, maximum=None): """Check whether the input parameters are reasonable int input""" if not isinstance(int_val, int): raise RuntimeError(f"Parameter '{arg_name}' should be int, but actually {type(int_val)}") if minimum is not None and int_val < minimum: if maximum is not None: raise RuntimeError(f"Parameter '{arg_name}' should be in range [{minimum},{maximum}]") raise RuntimeError(f"Parameter '{arg_name}' should be >= {minimum}") if maximum is not None and int_val > maximum: if minimum is not None: raise RuntimeError(f"Parameter '{arg_name}' should be in range [{minimum},{maximum}]") raise RuntimeError(f"Parameter '{arg_name}' should be <= {maximum}") class SSLConfig: """ The client's ssl_config encapsulates grpc's ssl channel credentials for SSL-enabled connections. Args: certificate (str, optional): File holding the PEM-encoded certificate chain as a byte string to use or ``None`` if no certificate chain should be used. Default: ``None``. private_key (str, optional): File holding the PEM-encoded private key as a byte string, or ``None`` if no private key should be used. Default: ``None``. custom_ca (str, optional): File holding the PEM-encoded root certificates as a byte string, or ``None`` to retrieve them from a default location chosen by gRPC runtime. Default: ``None``. Raises: RuntimeError: The type or value of the parameters is invalid. """ def __init__(self, certificate=None, private_key=None, custom_ca=None): if certificate is not None: _check_str("certificate", certificate) if private_key is not None: _check_str("private_key", private_key) if custom_ca is not None: _check_str("custom_ca", custom_ca) self.certificate = certificate self.private_key = private_key self.custom_ca = custom_ca class Client: """ The Client encapsulates the serving gRPC API, which can be used to create requests, access serving, and parse results. Note: The maximum amount of data that the client can send in one request is 512MB, and the maximum amount of data that the server can accept can be configured as 1~512MB, 100MB by default. Args: address (str): Serving address. servable_name (str): The name of servable supplied by Serving. method_name (str): The name of method supplied by servable. version_number (int, optional): The version number of servable, ``0`` means the maximum version number in all running versions. Default: ``0``. ssl_config (mindspore_serving.client.SSLConfig, optional): The server's ssl_config, if ``None``, disabled ssl. Default: ``None``. Raises: RuntimeError: The type or value of the parameters are invalid, or other errors happened. 
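    A minimal sketch of creating an SSL-enabled client; the certificate file names
    below are placeholders for files you provide, not files shipped with Serving:

    >>> from mindspore_serving.client import Client, SSLConfig
    >>> ssl_config = SSLConfig(certificate="client.crt", private_key="client.key", custom_ca="ca.crt")
    >>> client = Client("localhost:5500", "add", "add_cast", ssl_config=ssl_config)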
Examples: >>> from mindspore_serving.client import Client >>> import numpy as np >>> client = Client("localhost:5500", "add", "add_cast") >>> instances = [] >>> x1 = np.ones((2, 2), np.int32) >>> x2 = np.ones((2, 2), np.int32) >>> instances.append({"x1": x1, "x2": x2}) >>> result = client.infer(instances) >>> print(result) """ def __init__(self, address, servable_name, method_name, version_number=0, ssl_config=None): _check_str("address", address) _check_str("servable_name", servable_name) _check_str("method_name", method_name) _check_int("version_number", version_number, 0) self.address = address self.servable_name = servable_name self.method_name = method_name self.version_number = version_number msg_bytes_size = 512 * 1024 * 1024 # 512MB options = [ ('grpc.max_send_message_length', msg_bytes_size), ('grpc.max_receive_message_length', msg_bytes_size), ] if ssl_config is not None: if not isinstance(ssl_config, SSLConfig): raise RuntimeError("The type of ssl_config should be type of SSLConfig") rc_bytes = pk_bytes = c_bytes = None if ssl_config.certificate is not None: with open(ssl_config.certificate, 'rb') as c_fs: c_bytes = c_fs.read() if ssl_config.private_key is not None: with open(ssl_config.private_key, 'rb') as pk_fs: pk_bytes = pk_fs.read() if ssl_config.custom_ca is not None: with open(ssl_config.custom_ca, 'rb') as rc_fs: rc_bytes = rc_fs.read() if (c_bytes is None and pk_bytes is not None) or (c_bytes is not None and pk_bytes is None): raise RuntimeError("The certificate and private_key should be passed at the same time") creds = grpc.ssl_channel_credentials(root_certificates=rc_bytes, private_key=pk_bytes, certificate_chain=c_bytes) self.channel = grpc.secure_channel(address, creds, options=options) else: self.channel = grpc.insecure_channel(address, options=options) self.stub = ms_service_pb2_grpc.MSServiceStub(self.channel) def infer(self, instances): """ Used to create requests, access serving service, and parse and return results. Args: instances (Union[dict, tuple[dict]]): Instance or tuple of instances, every instance item is the inputs dict. The key is the input name, and the value is the input value, the type of value can be python int, float, bool, str, bytes, numpy number, or numpy array object. Raises: RuntimeError: The type or value of the parameters is invalid, or other errors happened. Examples: >>> from mindspore_serving.client import Client >>> import numpy as np >>> client = Client("localhost:5500", "add", "add_cast") >>> instances = [] >>> x1 = np.ones((2, 2), np.int32) >>> x2 = np.ones((2, 2), np.int32) >>> instances.append({"x1": x1, "x2": x2}) >>> result = client.infer(instances) >>> print(result) """ request = self._create_request(instances) try: result = self.stub.Predict(request) return self._paser_result(result) except grpc.RpcError as e: print(e.details()) status_code = e.code() print(status_code.name) print(status_code.value) return {"error": f"Grpc Error, {status_code.value}, {e.details()}"} def infer_async(self, instances): """ Used to create requests, async access serving. Args: instances (Union[dict, tuple[dict]]): Instance or tuple of instances, every instance item is the inputs dict. The key is the input name, and the value is the input value, the type of value can be python int, float, bool, str, bytes, numpy number, or numpy array object. Raises: RuntimeError: The type or value of the parameters is invalid, or other errors happened. 
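        Returns:
            ClientGrpcAsyncResult if the request was sent, whose ``result()`` method
            blocks until the reply arrives and then parses it; ClientGrpcAsyncError if
            the gRPC call failed immediately, whose ``result()`` returns the error dict.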
Examples: >>> from mindspore_serving.client import Client >>> import numpy as np >>> client = Client("localhost:5500", "add", "add_cast") >>> instances = [] >>> x1 = np.ones((2, 2), np.int32) >>> x2 = np.ones((2, 2), np.int32) >>> instances.append({"x1": x1, "x2": x2}) >>> result_future = client.infer_async(instances) >>> result = result_future.result() >>> print(result) """ request = self._create_request(instances) try: result_future = self.stub.Predict.future(request) return ClientGrpcAsyncResult(result_future) except grpc.RpcError as e: print(e.details()) status_code = e.code() print(status_code.name) print(status_code.value) return ClientGrpcAsyncError({"error": f"Grpc Error, {status_code.value}, {e.details()}"}) def _create_request(self, instances): """Used to create request spec.""" if not isinstance(instances, (tuple, list)): instances = (instances,) request = ms_service_pb2.PredictRequest() request.servable_spec.name = self.servable_name request.servable_spec.method_name = self.method_name request.servable_spec.version_number = self.version_number for item in instances: if isinstance(item, dict): request.instances.append(self._create_instance(**item)) else: raise RuntimeError("instance should be a map") return request @staticmethod def _create_instance(**kwargs): """Used to create gRPC instance.""" instance = ms_service_pb2.Instance() for k, w in kwargs.items(): tensor = instance.items[k] if isinstance(w, (np.ndarray, np.number)): _create_tensor(w, tensor) elif isinstance(w, str): _create_str_tensor(w, tensor) elif isinstance(w, (bool, int, float)): _create_scalar_tensor(w, tensor) elif isinstance(w, bytes): _create_bytes_tensor(w, tensor) else: raise RuntimeError("Not support value type " + str(type(w))) return instance @staticmethod def _paser_result(result): """Used to parse result.""" error_msg_len = len(result.error_msg) if error_msg_len == 1 and result.error_msg[0].error_code != 0: return {"error": bytes.decode(result.error_msg[0].error_msg)} ret_val = [] instance_len = len(result.instances) if error_msg_len not in (0, instance_len): raise RuntimeError(f"error msg result size {error_msg_len} not be 0, 1 or " f"length of instances {instance_len}") for i in range(instance_len): instance = result.instances[i] if error_msg_len == 0 or result.error_msg[i].error_code == 0: instance_map = {} for k, w in instance.items.items(): instance_map[k] = _create_numpy_from_tensor(w) ret_val.append(instance_map) else: ret_val.append({"error": bytes.decode(result.error_msg[i].error_msg)}) return ret_val class ClientGrpcAsyncResult: """ When Client.infer_async invoke successfully, a ClientGrpcAsyncResult object is returned. Examples: >>> from mindspore_serving.client import Client >>> import numpy as np >>> client = Client("localhost:5500", "add", "add_cast") >>> instances = [] >>> x1 = np.ones((2, 2), np.int32) >>> x2 = np.ones((2, 2), np.int32) >>> instances.append({"x1": x1, "x2": x2}) >>> result_future = client.infer_async(instances) >>> result = result_future.result() >>> print(result) """ def __init__(self, result_future): self.result_future = result_future def result(self): """Wait and get result of inference result, the gRPC message will be parse to tuple of instances result. Every instance result is dict, and value could be numpy array/number, str or bytes according gRPC Tensor data type. 
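        Returns:
            Union[dict, list[dict]], a list with one dict per instance, or a single
            ``{"error": ...}`` dict when the server reports a request-level error,
            mirroring the return value of Client.infer.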
""" result = self.result_future.result() # pylint: disable=protected-access result = Client._paser_result(result) return result class ClientGrpcAsyncError: """When gRPC failed happened when calling Client.infer_async, a ClientGrpcAsyncError object is returned. """ def __init__(self, result_error): self.result_error = result_error def result(self): """Get gRPC error message. """ return self.result_error ================================================ FILE: mindspore_serving/log.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """ log module """ import sys import os import stat import time import logging from logging.handlers import RotatingFileHandler import traceback import threading import platform if platform.system() != "Windows": import fcntl __all__ = ['get_level', 'get_log_config'] # The lock for setting up the logger _setup_logger_lock = threading.Lock() # When getting the logger, Used to check whether # the logger already exists _global_logger = None # The flag for enable console output _std_on = '1' # The flag for disable console output _std_off = '0' # Rotating max bytes, default is 50M _logger_def_max_bytes = '52428800' # Rotating backup count, default is 30 _logger_def_backup_count = '30' # The default log level _logger_def_level = '2' # Log level name and level mapping _name_to_level = { 'ERROR': 40, 'WARNING': 30, 'INFO': 20, 'DEBUG': 10, } # GLog level and level name _gloglevel_to_name = { '3': 'ERROR', '2': 'WARNING', '1': 'INFO', '0': 'DEBUG', } # The mapping of logger configurations to glog configurations _confmap_dict = {'level': 'GLOG_v', 'console': 'GLOG_logtostderr', 'filepath': 'GLOG_log_dir', 'maxBytes': 'logger_maxBytes', 'backupCount': 'logger_backupCount', 'stderr_level': 'GLOG_stderrthreshold'} class _MultiCompatibleRotatingFileHandler(RotatingFileHandler): """Inherit RotatingFileHandler for multiprocess compatibility.""" def rolling_rename(self): """Rolling rename log files and set permission of Log file""" for i in range(self.backupCount - 1, 0, -1): sfn = self.rotation_filename("%s.%d" % (self.baseFilename, i)) dfn = self.rotation_filename("%s.%d" % (self.baseFilename, i + 1)) if os.path.exists(sfn): if os.path.exists(dfn): os.remove(dfn) # Modify the permission of Log file os.chmod(sfn, stat.S_IREAD) os.rename(sfn, dfn) def doRollover(self): """Override doRollover for multiprocess compatibility and setting permission of Log file. """ if self.stream: self.stream.close() self.stream = None # Attain an exclusive lock with blocking mode by `fcntl` module. 
with open(self.baseFilename, 'a') as file_pointer: if platform.system() != "Windows": fcntl.lockf(file_pointer.fileno(), fcntl.LOCK_EX) if self.backupCount > 0: self.rolling_rename() dfn = self.rotation_filename(self.baseFilename + ".1") if os.path.exists(dfn): os.remove(dfn) # Modify the permission of Log file os.chmod(self.baseFilename, stat.S_IREAD) self.rotate(self.baseFilename, dfn) with open(self.baseFilename, 'a'): # Modify the permission of Log file os.chmod(self.baseFilename, stat.S_IREAD | stat.S_IWRITE) if not self.delay: self.stream = self._open() class _DataFormatter(logging.Formatter): """Log formatter""" def __init__(self, sub_module, fmt=None, **kwargs): """ Initialization of logFormatter. Args: sub_module (str): The submodule name. fmt (str): Specified format pattern. Default: None. """ super(_DataFormatter, self).__init__(fmt=fmt, **kwargs) self.sub_module = sub_module.upper() def formatTime(self, record, datefmt=None): """ Override formatTime for uniform format %Y-%m-%d-%H:%M:%S.SSS.SSS Args: record (str): Log record. datefmt (str): Date format. Returns: str, formatted timestamp. """ created_time = self.converter(record.created) if datefmt: return time.strftime(datefmt, created_time) timestamp = time.strftime('%Y-%m-%d-%H:%M:%S', created_time) msecs = str(round(record.msecs * 1000)) # Format the time stamp return f'{timestamp}.{msecs[:3]}.{msecs[3:]}' def format(self, record): """ Apply log format with specified pattern. Args: record (str): Format pattern. Returns: str, formatted log content according to format pattern. """ # NOTICE: when the Installation directory of mindspore changed, # ms_home_path must be changed ms_install_home_path = 'mindspore' idx = record.pathname.rfind(ms_install_home_path) if idx >= 0: # Get the relative path of the file record.filepath = record.pathname[idx:] else: record.filepath = record.pathname record.sub_module = self.sub_module return super().format(record) def _get_logger(): """ Get logger instance. Returns: Logger, a logger. """ if _global_logger: return _global_logger kwargs = _get_env_config() _verify_config(kwargs) logger = _setup_logger(_adapt_cfg(kwargs)) return logger def _adapt_cfg(kwargs): """ Glog configurations converted to logger configurations. Args: kwargs (dict): The dictionary of log configurations. - console (str): Whether to output log to stdout. - level (str): Log level. - filepath (str): The path for saving logs, if console is false, a file path must be assigned. - maxBytes (str): The Maximum value of a log file for rotating, only valid if console is false. - backupCount (str): The count of rotating backup log files, only valid if console is false. Returns: Dict, the input parameter dictionary. """ kwargs['level'] = _gloglevel_to_name.get(kwargs.get('level', _logger_def_level)) kwargs['stderr_level'] = _gloglevel_to_name.get(kwargs.get('stderr_level', _logger_def_level)) kwargs['console'] = not kwargs.get('console') == _std_off kwargs['maxBytes'] = int(kwargs.get('maxBytes', _logger_def_max_bytes)) kwargs['backupCount'] = int(kwargs.get('backupCount', _logger_def_backup_count)) return kwargs def info(msg, *args, **kwargs): """ Log a message with severity 'INFO' on the MindSpore logger. Examples: >>> from mindspore_serving import log as logger >>> logger.info("The arg(%s) is: %r", name, arg) """ _get_logger().info(msg, *args, **kwargs) def debug(msg, *args, **kwargs): """ Log a message with severity 'DEBUG' on the MindSpore logger. 
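    Accepts %-style formatting arguments and forwards them unchanged to the
    underlying ``logging.Logger.debug`` call of the lazily created global logger.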
    Examples:
        >>> from mindspore_serving import log as logger
        >>> logger.debug("The arg(%s) is: %r", name, arg)
    """
    _get_logger().debug(msg, *args, **kwargs)


def error(msg, *args, **kwargs):
    """Log a message with severity 'ERROR' on the MindSpore logger."""
    _get_logger().error(msg, *args, **kwargs)


def warning(msg, *args, **kwargs):
    """Log a message with severity 'WARNING' on the MindSpore logger."""
    _get_logger().warning(msg, *args, **kwargs)


def get_level():
    """
    Get the logger level.

    Returns:
        str, the log level includes 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG).

    Examples:
        >>> import os
        >>> os.environ['GLOG_v'] = '0'
        >>> from mindspore_serving import log as logger
        >>> logger.get_level()
    """
    # level and glog level mapping dictionary
    level_to_glog_level = dict(zip(_name_to_level.values(), _gloglevel_to_name.keys()))
    return level_to_glog_level.get(_get_logger().getEffectiveLevel())


def _get_formatter():
    """
    Get the string of the log formatter.

    Returns:
        str, the string of the log formatter.
    """
    formatter = '[%(levelname)s] %(sub_module)s(%(process)d:' \
                '%(thread)d,%(processName)s):%(asctime)s ' \
                '[%(filepath)s:%(lineno)d] %(message)s'
    return formatter


def _get_env_config():
    """
    Get configurations from environment variables.

    Returns:
        Dict, the dictionary of configurations.
    """
    config_dict = {}
    for key, env_value in _confmap_dict.items():
        value = os.environ.get(env_value)
        if value:
            config_dict[key] = value.strip()
    return config_dict


def _verify_config(kwargs):
    """
    Verify log configurations.

    Args:
        kwargs (dict): The dictionary of log configurations.

            - console (str): Whether to output log to stdout.
            - level (str): Log level.
            - filepath (str): The path for saving logs, if console is false, a file path must be assigned.
            - maxBytes (str): The maximum size of a log file for rotating, only valid if console is false.
            - backupCount (str): The count of rotating backup log files, only valid if console is false.
    """
    # Check the input value of level
    level = kwargs.get('level', None)
    if level is not None:
        _verify_level(level)

    # Check the input value of stderr_level
    level = kwargs.get('stderr_level', None)
    if level is not None:
        _verify_level(level)

    # Check the input value of console
    console = kwargs.get('console', None)
    file_path = kwargs.get('filepath', None)
    if console is not None:
        if not console.isdigit() or console not in (_std_off, _std_on):
            raise ValueError(f'Incorrect value: the value of {_confmap_dict["console"]} must be 0 or 1. '
                             f'To output logs to the console, configure it to 1.')
        if console == _std_off and not file_path:
            raise ValueError(f'When {_confmap_dict["console"]} is set to 0, the directory for '
                             f'saving logs must be set, {_confmap_dict["filepath"]} cannot be empty.')

    # Check the input value of filepath
    if console == _std_off and file_path is not None:
        file_real_path = os.path.realpath(file_path)
        if not os.path.exists(file_real_path):
            raise ValueError(f'The file path does not exist. '
                             f'{_confmap_dict["filepath"]}:{file_path}')

    # Check the input value of maxBytes
    max_bytes = kwargs.get('maxBytes', None)
    if console == _std_off and max_bytes is not None:
        if not max_bytes.isdigit():
            raise ValueError(f'Incorrect value: the value of {_confmap_dict["maxBytes"]} must be a positive '
                             f'integer. {_confmap_dict["maxBytes"]}:{max_bytes}')

    # Check the input value of backupCount
    backup_count = kwargs.get('backupCount', None)
    if console == _std_off and backup_count is not None:
        if not backup_count.isdigit():
            raise ValueError(f'Incorrect value: the value of {_confmap_dict["backupCount"]} must be a positive '
                             f'integer. {_confmap_dict["backupCount"]}:{backup_count}')


def _verify_level(level):
    """
    Verify log level.

    Args:
        level (str): The log level.
    """
    level_name = _gloglevel_to_name.get(level, None)
    # Check the value of input level
    if level_name not in _name_to_level:
        raise ValueError(f'Incorrect log level: {level}. Please check the configuration of GLOG_v or '
                         f'GLOG_stderrthreshold, desired log levels: {_gloglevel_to_name}')


def get_log_config():
    """
    Get logger configurations.

    Returns:
        Dict, the dictionary of logger configurations.

    Examples:
        >>> import os
        >>> os.environ['GLOG_v'] = '1'
        >>> os.environ['GLOG_logtostderr'] = '0'
        >>> os.environ['GLOG_log_dir'] = '/var/log/mindspore'
        >>> os.environ['logger_maxBytes'] = '5242880'
        >>> os.environ['logger_backupCount'] = '10'
        >>> from mindspore_serving import log as logger
        >>> logger.get_log_config()
    """
    logger = _get_logger()
    handler = logger.handlers[0]
    config_dict = {}
    config_dict['GLOG_v'] = get_level()
    config_dict['GLOG_logtostderr'] = _std_on

    if handler.name == 'FileHandler':
        config_dict['GLOG_logtostderr'] = _std_off
        # Separating file path and name
        file_path_and_name = os.path.split(handler.baseFilename)
        config_dict['GLOG_log_dir'] = file_path_and_name[0]
        config_dict['logger_maxBytes'] = handler.maxBytes
        config_dict['logger_backupCount'] = handler.backupCount
        handler_stderr = logger.handlers[1]
        # level and glog level mapping dictionary
        level_to_glog_level = dict(zip(_name_to_level.values(), _gloglevel_to_name.keys()))
        config_dict['GLOG_stderrthreshold'] = level_to_glog_level.get(handler_stderr.level)
    return config_dict


def _clear_handler(logger):
    """Clear the handlers that have been set, to avoid repeated loading"""
    for handler in logger.handlers:
        logger.removeHandler(handler)


def _find_caller(stack_info=False, _=1):
    """
    Find the stack frame of the caller.

    Overrides findCaller on the logger to support getting the log record.
    Find the stack frame of the caller so that we can note the source
    file name, function name and line number.

    Args:
        stack_info (bool): If the value is true, print stack information to the log. Default: False.

    Returns:
        tuple, the tuple of the frame data.
    """
    # pylint: disable=protected-access
    f = sys._getframe(3)
    sinfo = None
    # log_file is used to check caller stack frame
    log_file = os.path.normcase(f.f_code.co_filename)
    f = f.f_back
    rv = "(unknown file)", 0, "(unknown function)", None
    while f:
        co = f.f_code
        filename = os.path.normcase(co.co_filename)
        if log_file == filename:
            f = f.f_back
            continue
        if stack_info:
            sinfo = _get_stack_info(f)
        rv = (co.co_filename, f.f_lineno, co.co_name, sinfo)
        break
    return rv


def _get_stack_info(frame):
    """
    Get the stack information.

    Args:
        frame (frame): the frame requiring information.

    Returns:
        str, the string of the stack information.
    """
    stack_prefix = 'Stack (most recent call last):\n'
    sinfo = stack_prefix + "".join(traceback.format_stack(frame))
    return sinfo


def _setup_logger(kwargs):
    """
    Set up the logger.

    Args:
        kwargs (dict): The dictionary of log configurations.

            - console (bool): Whether to output log to stdout. Default: True.
            - level (str): Log level. Default: WARNING.
            - filepath (str): The path for saving logs, if console is false, a file path must be assigned.
            - maxBytes (int): The maximum size of a log file for rotating, only valid if console is false.
              Default: 52428800.
            - backupCount (int): The count of rotating backup log files, only valid if console is false.
              Default: 30.

    Returns:
        Logger, well-configured logger.
""" # The name of Submodule sub_module = 'SERVING' # The name of Base log file pid = str(os.getpid()) log_name = 'mindspore_serving.log.' + pid global _global_logger _setup_logger_lock.acquire() try: if _global_logger: return _global_logger logger = logging.getLogger(name=f'{sub_module}.{log_name}') # Override findCaller on the logger, Support for getting log record logger.findCaller = _find_caller console = kwargs.get('console', True) # Set log level logger.setLevel(kwargs.get('level', logging.WARNING)) # Set "propagate" attribute to False, stop searching up the hierarchy, # avoid to load the handler of the root logger logger.propagate = False # Get the formatter for handler formatter = _get_formatter() # Clean up handle to avoid repeated loading _clear_handler(logger) # Set streamhandler for the console appender if console: console_handler = logging.StreamHandler(sys.stderr) console_handler.name = 'StreamHandler' console_handler.formatter = _DataFormatter(sub_module, formatter) logger.addHandler(console_handler) # Set rotatingFileHandler for the file appender else: # filepath cannot be null, checked in function _verify_config () logfile_dir = os.path.realpath(kwargs.get('filepath')) file_name = f'{logfile_dir}/{log_name}' logfile_handler = _MultiCompatibleRotatingFileHandler( filename=file_name, # Rotating max bytes, default is 50M maxBytes=kwargs.get('maxBytes', _logger_def_max_bytes), # Rotating backup count, default is 30 backupCount=kwargs.get('backupCount', _logger_def_backup_count), encoding='utf8' ) logfile_handler.name = 'FileHandler' logfile_handler.formatter = _DataFormatter(sub_module, formatter) logger.addHandler(logfile_handler) # Write the file and output warning and error logs to stderr console_handler = logging.StreamHandler(sys.stderr) console_handler.name = 'StreamHandler' console_handler.formatter = _DataFormatter(sub_module, formatter) console_handler.setLevel(kwargs.get('stderr_level', logging.WARNING)) logger.addHandler(console_handler) _global_logger = logger finally: _setup_logger_lock.release() return _global_logger ================================================ FILE: mindspore_serving/proto/ms_agent.proto ================================================ /** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */
// ms_agent.proto
syntax = "proto3";

package mindspore.serving.proto;

import "mindspore_serving/proto/ms_service.proto";

message DistributedPredictRequest {
  repeated Tensor inputs = 1;
  bool return_result = 2;
  int64 subgraph = 3;
}

message DistributedPredictReply {
  repeated Tensor outputs = 1;
  ErrorMsg error_msg = 2;
}

message DistributedExitRequest {
  string address = 1;
}

message DistributedExitReply {
  ErrorMsg error_msg = 1;
}

service MSAgent {
  rpc Predict(DistributedPredictRequest) returns (DistributedPredictReply) {}
  rpc Exit(DistributedExitRequest) returns (DistributedExitReply) {}
  rpc Ping(PingRequest) returns (PingReply) {}
  rpc Pong(PongRequest) returns (PongReply) {}
}

================================================
FILE: mindspore_serving/proto/ms_distributed.proto
================================================
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// ms_distributed.proto
syntax = "proto3";

package mindspore.serving.proto;

import "mindspore_serving/proto/ms_service.proto";

message AgentSpec {
  int64 rank_id = 1;
  int64 batch_size = 2;
  repeated TensorInfo inputs = 3;
  repeated TensorInfo outputs = 4;
}

message CommonModelMeta {
  string servable_name = 1;
  string model_key = 2;
  bool with_batch_dim = 3;
  repeated int64 without_batch_dim_inputs = 4;
  // keyed by subgraph index; the map type parameters were lost in extraction
  // and are reconstructed here
  map<int64, int64> inputs_count = 5;
  map<int64, int64> outputs_count = 6;
}

message DistributedModelMeta {
  int64 rank_size = 1;
  int64 stage_size = 2;
}

message AgentRegisterRequest {
  repeated AgentSpec agent_spec = 1;
  string address = 2;
}

message AgentRegisterReply {
  ErrorMsg error_msg = 1;
}

message AgentExitRequest {
  oneof address_choice {
    string address = 1;  // by agent process
    string agent_ip = 2;  // by agent start up process
  }
}

message AgentExitReply {
  ErrorMsg error_msg = 1;
}

message AgentFailedRequest {
}

message AgentFailedReply {
  ErrorMsg error_msg = 1;
}

message AgentConfigAcquireRequest {
}

message AgentConfigAcquireReply {
  message OneRankConfig {
    string ip = 1;
    int64 device_id = 2;
  }
  string rank_table_content = 1;
  repeated OneRankConfig rank_list = 2;
  CommonModelMeta common_meta = 3;
  DistributedModelMeta distributed_meta = 4;
}

================================================
FILE: mindspore_serving/proto/ms_master.proto
================================================
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// ms_master.proto
syntax = "proto3";

package mindspore.serving.proto;

import "mindspore_serving/proto/ms_service.proto";

service MSMaster {
  rpc Register(RegisterRequest) returns (RegisterReply) {}
  rpc Exit(ExitRequest) returns (ExitReply) {}
  rpc NotifyFailed(NotifyFailedRequest) returns (NotifyFailedReply) {}
  rpc CallModel(PredictRequest) returns (PredictReply) {}
  rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoReply) {}
}

message ServableRegSpec {
  string name = 1;
  uint64 version_number = 2;
  uint64 batch_size = 4;
  message MethodInfo {
    string name = 1;
    repeated string input_names = 2;
    bool only_model_stage = 3;
  }
  repeated MethodInfo methods = 5;
  ModelInfos model_infos = 6;  // keyed by model key
  bool own_device = 7;
}

message WorkerRegSpec {
  uint64 worker_pid = 1;
  string address = 2;
  ServableRegSpec servable_spec = 4;
}

message RegisterRequest {
  WorkerRegSpec worker_spec = 1;
}

message RegisterReply {
  ErrorMsg error_msg = 1;
}

message ExitRequest {
  string address = 1;
}

message ExitReply {
  ErrorMsg error_msg = 1;
}

message NotifyFailedRequest {
  uint64 worker_pid = 1;
  string error_msg = 2;
}

message NotifyFailedReply {
}

message GetModelInfoRequest {
  string servable_name = 1;
  uint32 version_number = 2;
}

message ModelSubGraphInfo {
  repeated TensorInfo inputs = 3;
  repeated TensorInfo outputs = 4;
}

message ModelInfo {
  uint64 batch_size = 2;
  repeated ModelSubGraphInfo subgraph_infos = 1;
}

message ModelInfos {
  // keyed by model key; the map type parameters were lost in extraction
  // and are reconstructed here
  map<string, ModelInfo> model_infos = 1;
}

message GetModelInfoReply {
  string servable_name = 1;
  uint32 version_number = 2;
  ModelInfos model_infos = 3;
  ErrorMsg error_msg = 4;
}

================================================
FILE: mindspore_serving/proto/ms_service.proto
================================================
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// ms_service.proto
syntax = "proto3";

package mindspore.serving.proto;

service MSService {
  rpc Predict(PredictRequest) returns (PredictReply) {}
}

message PredictRequest {
  ServableSpec servable_spec = 1;
  repeated Instance instances = 2;
}

message ErrorMsg {
  int64 error_code = 1;  // 0 is valid, otherwise invalid
  bytes error_msg = 2;
}

message PredictReply {
  ServableSpec servable_spec = 1;
  repeated Instance instances = 3;
  // size 0: OK, 1: for all batch, >1: for every batch
  repeated ErrorMsg error_msg = 4;
}

message Instance {
  // the map type parameters were lost in extraction and are reconstructed here
  map<string, Tensor> items = 1;
  map<string, ShmTensorData> output_buffers = 2;
}

enum DataType {
  MS_UNKNOWN = 0;
  MS_BOOL = 1;
  MS_INT8 = 2;
  MS_UINT8 = 3;
  MS_INT16 = 4;
  MS_UINT16 = 5;
  MS_INT32 = 6;
  MS_UINT32 = 7;
  MS_INT64 = 8;
  MS_UINT64 = 9;
  MS_FLOAT16 = 10;
  MS_FLOAT32 = 11;
  MS_FLOAT64 = 12;
  MS_STRING = 13;  // for string model input
  MS_BYTES = 14;  // for images
}

message TensorShape {
  repeated int64 dims = 1;
}

message ShmTensorData {
  string memory_key = 1;
  uint64 bytes_size = 2;  // the total shared memory size
  uint64 data_offset = 3;
  uint64 data_size = 4;
}

message Tensor {
  // tensor shape info
  TensorShape shape = 1;
  // tensor content data type
  DataType dtype = 2;
  // tensor data
  oneof tensor_data {
    bytes data = 3;
    ShmTensorData shm_data = 5;
  }
  // for string type and images, the dtype is MS_BYTES.
  repeated bytes bytes_val = 4;
}

message ServableSpec {
  // servable name
  string name = 1;
  // optional. If unspecified, the latest version servable will be used.
  uint64 version_number = 3;
  // Specifies the method name in the servable.
  string method_name = 2;
}

message PingRequest {
  string address = 1;
}

message PingReply {
  string address = 1;
}

message PongRequest {
  string address = 1;
}

message PongReply {
  string address = 1;
}

message TensorInfo {
  TensorShape shape = 1;  // tensor shape info
  DataType dtype = 2;  // tensor content data type
  uint64 size = 3;
  bool is_no_batch_dim = 4;
}

================================================
FILE: mindspore_serving/proto/ms_worker.proto
================================================
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// ms_worker.proto
syntax = "proto3";

package mindspore.serving.proto;

import "mindspore_serving/proto/ms_service.proto";
import "mindspore_serving/proto/ms_master.proto";
import "mindspore_serving/proto/ms_distributed.proto";

service MSWorker {
  // for master
  rpc Predict(PredictRequest) returns (PredictReply) {}
  rpc Exit(ExitRequest) returns (ExitReply) {}
}

service MSDistributedWorker {
  // for worker agent
  rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {}
  rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {}
  rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {}
  rpc AgentConfigAcquire(AgentConfigAcquireRequest) returns (AgentConfigAcquireReply) {}
  rpc Ping(PingRequest) returns (PingReply) {}
  rpc Pong(PongRequest) returns (PongReply) {}
}

================================================
FILE: mindspore_serving/server/__init__.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MindSpore Serving is a lightweight and high-performance service module that helps MindSpore developers
efficiently deploy online inference services in the production environment.

This is the MindSpore Serving server API, which can be used to start servables and the gRPC and RESTful
servers. A servable corresponds to the service provided by a model. The client sends inference tasks and
receives inference results through the gRPC and RESTful servers.
"""

from .master import start_grpc_server, start_restful_server, stop, SSLConfig
from ._server import start_servables, ServableStartConfig
from . import register
from . import distributed

__all__ = []

__all__.extend([
    "start_grpc_server", "start_restful_server", "stop",
    "start_servables", 'ServableStartConfig', "SSLConfig"
])
__all__.extend(register.__all__)
__all__.extend(distributed.__all__)

================================================
FILE: mindspore_serving/server/_servable_common.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Interface for starting up servables"""

import os
import time
import threading
import signal
import psutil

import mindspore_serving.log as logger
from mindspore_serving._mindspore_serving import WorkerContext_


class ServableContextDataBase:
    """Used to start up the servable process"""

    def __init__(self):
        pass

    @property
    def servable_name(self):
        raise NotImplementedError

    @property
    def version_number(self):
        raise NotImplementedError

    def to_string(self):
        """For logging"""
        raise NotImplementedError

    def new_worker_process(self):
        """Start worker process to provide servable"""
        raise NotImplementedError

    def can_restart(self):
        """Whether the worker can restart"""
        return True

    def own_device(self):
        """Whether the worker occupies a device"""
        return True


class WorkerContext:
    """Used to monitor and manage workers"""

    def __init__(self, context_data, master_address, sub_process):
        if not isinstance(context_data, ServableContextDataBase):
            raise RuntimeError(f"Parameter 'context_data' should be an instance of ServableContextDataBase, "
                               f"but actually {type(context_data)}")
        self.context_data_ = context_data
        self.master_address_ = master_address
        self.sub_process_ = sub_process
        self.last_not_alive_time_ = None
        self.is_in_process_switching_ = False
        self.context = WorkerContext_.init_worker(context_data.servable_name, context_data.version_number,
                                                  context_data.to_string(), sub_process.pid)

    @property
    def servable_name(self):
        return self.context_data_.servable_name

    @property
    def worker_pid(self):
        return self.sub_process_.pid

    @property
    def master_address(self):
        return self.master_address_

    def to_string(self):
        """For logging"""
        return f"{self.context_data_.to_string()}, pid: {self.worker_pid}"

    @property
    def is_in_process_switching(self):
        return self.is_in_process_switching_

    def own_device(self):
        return self.context_data_.own_device()

    def ready(self):
        """Whether the worker is ready to provide service"""
        return self.context.ready()

    def print_status(self):
        """DEBUG, used to print worker status"""
        self.context.print_status()

    def is_in_starting(self):
        """Whether the worker is in the process of startup"""
        return self.context.is_in_starting()

    def has_error_notified(self):
        """Whether an error was reported by the worker process during startup"""
        return self.context.has_error_notified()

    # Error message of the worker notifying the master
    def get_notified_error(self):
        return self.context.get_notified_error()

    def has_exit_notified(self):
        """Whether exit was reported by the worker process"""
        return self.context.has_exit_notified()

    # Exit message of the worker notifying the master
    def can_be_restart(self):
        """Whether the worker process can be restarted"""
        if not self.context_data_.can_restart():
            return False
        normal_handled_count = self.context.normal_handled_count
        return normal_handled_count > 0

    def exit_for_enough_time(self):
        """Whether the worker has been dead for more than 1s; wait 1s for the worker's exit or error message"""
        return self.last_not_alive_time_ and (time.time() - self.last_not_alive_time_ > 1)

    def is_alive(self):
        """Whether the worker process is alive"""
        alive = (self.sub_process_.poll() is None)
        if not alive:
            if not self.last_not_alive_time_:
                self.context.notify_not_alive()
                self.last_not_alive_time_ = time.time()
        else:
            self.last_not_alive_time_ = None
        return alive

    def is_unavailable(self):
        """Whether the worker process is unavailable, i.e. cannot be linked to provide services"""
        if self.is_in_process_switching:  # restart: shutdown and start worker
            return False
        if self.is_in_starting():  # start worker
            return False
        return self.context.is_unavailable

    def update_worker_process(self, new_sub_process):
        """Update worker process pid"""
        self.context.update_worker_pid(new_sub_process.pid)
        self.sub_process_ = new_sub_process
        self.last_not_alive_time_ = None

    def _terminate(self):
        self.sub_process_.terminate()

    def _shutdown_worker(self):
        """Shutdown worker process"""
        if not self.is_alive():
            return
        self._terminate()
        for _ in range(100):  # 10s
            if not self.is_alive():
                return
            time.sleep(0.1)
        self.send_exit_signal(signal.SIGKILL)
        self.context.notify_not_alive()

    def _restart_worker(self):
        """Restart worker process"""
        logger.info(f"restart worker, {self.to_string()}")
        self._shutdown_worker()
        try:
            new_sub_process = self.context_data_.new_worker_process()
        except RuntimeError as e:
            logger.error(f"Start worker failed: {e}")
            self.context.notify_start_failed(f"Start worker failed: {e}")
            return
        self.update_worker_process(new_sub_process)

    def shutdown_worker(self):
        """Shutdown worker process in a thread"""
        self.handle_worker_process(self._shutdown_worker)

    def restart_worker(self):
        """Restart worker process in a thread"""
        self.handle_worker_process(self._restart_worker)

    def handle_worker_process(self, thread_fun):
        """Run thread_fun in a background thread, marking the worker as switching processes meanwhile"""
        self.is_in_process_switching_ = True

        def handle_thread():
            thread_fun()
            self.is_in_process_switching_ = False

        thread = threading.Thread(target=handle_thread)
        thread.start()

    def send_exit_signal(self, sig):
        """Send a signal to the worker process, used to exit the worker process"""
        if not self.is_alive():
            return
        logger.warning(f"Send signal {sig} to worker, {self.to_string()}")
        try:
            child_process = psutil.Process(self.worker_pid)
            if not child_process.is_running():
                return
            children_of_child = child_process.children(recursive=True)
            for item in children_of_child:
                os.kill(item.pid, sig)
            self.sub_process_.send_signal(sig)
        except psutil.NoSuchProcess:
            return
        except Exception as e:  # pylint: disable=broad-except
            logger.warning(f"Got exception when sending signal {sig} to worker, {self.to_string()}, "
                           f"exception: {e}")


================================================
FILE: mindspore_serving/server/_servable_local.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Interface for starting up a single-core servable"""

import os
import random
import sys
import subprocess

from mindspore_serving import log as logger
from mindspore_serving.server.common import check_type, get_abs_path
from mindspore_serving.server.worker import get_newest_version_number
from mindspore_serving.server._servable_common import ServableContextDataBase


def _get_device_type(target_device_type, enable_lite):
    """Get the supported device type; this loads libmindspore.so or libmindspore-lite.so"""
    # Get device type: Ascend, Gpu, Cpu
    args = f"{sys.executable} -c 'from mindspore_serving._mindspore_serving import Worker_;" \
           f"device_type=Worker_.get_device_type(\"{target_device_type}\", {enable_lite});" \
           f"print(\"#get_device_type_result=\", device_type, \"#\", sep=\"\")'"
    process = subprocess.Popen(args=args, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    process.wait()
    result = process.stdout.read().decode("utf-8")
    prefix = "#get_device_type_result="
    index = result.find(prefix)
    if index < 0:
        raise RuntimeError("Failed to get device type")
    index += len(prefix)
    end_index = result.find("#", index)
    device_type = result[index:end_index]
    return device_type


def _all_reuse_device():
    """Get whether device reuse is allowed; this loads libmindspore.so or libmindspore-lite.so"""
    # Whether device reuse is allowed: return False for Ascend 910, True for the others
    args = f"{sys.executable} -c 'from mindspore_serving._mindspore_serving import Worker_;" \
           f"reuse_flag=Worker_.support_reuse_device();" \
           f"print(\"#get_reuse_flag_result=\", reuse_flag, \"#\", sep=\"\")'"
    process = subprocess.Popen(args=args, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    process.wait()
    result = process.stdout.read().decode("utf-8")
    prefix = "#get_reuse_flag_result="
    index = result.find(prefix)
    if index < 0:
        raise RuntimeError("Failed to get device reuse flag")
    index += len(prefix)
    end_index = result.find("#", index)
    # pylint: disable=simplifiable-if-expression
    reuse_flag = True if result[index:end_index] == 'True' else False
    return reuse_flag


class ServableStartConfig:
    r"""
    Servable startup configuration.

    For more detail, please refer to
    `MindSpore-based Inference Service Deployment `_ and
    `Servable Provided Through Model Configuration `_.

    Args:
        servable_directory (str): The directory where the servable is located. It is expected to contain a
            directory named `servable_name`.
        servable_name (str): The servable name.
        device_ids (Union[int, list[int], tuple[int]], optional): The devices the model is loaded into and
            runs on. Used when the device type is Nvidia GPU, Ascend 310P/910. Default: None.
        version_number (int, optional): Servable version number to be loaded. The version number should be a
            positive integer, starting from 1, and 0 means to load the latest version. Default: 0.
        device_type (str, optional): Target device type for model deployment. Currently supports "Ascend",
            "GPU", "CPU" and None. Default: None.

            - "Ascend": the platform expected to be Ascend 310P/910, etc.
            - "GPU": the platform expected to be Nvidia GPU.
            - "CPU": the platform expected to be CPU.
            - None: the platform is determined by the MindSpore environment.

        num_parallel_workers (int, optional): The number of processes that handle python tasks, at least the
            number of device cards specified by the parameter device_ids. It will be adjusted up to the number
            of device cards when it is less than the number of device cards. The value should be in range
            [0, 64]. Default: 0.
dec_key (bytes, optional): Byte type key used for decryption. The valid length is 16, 24, or 32. Default: None. dec_mode (str, optional): Specifies the decryption mode, take effect when dec_key is set. Option: 'AES-GCM' or 'AES-CBC'. Default: 'AES-GCM'. Raises: RuntimeError: The type or value of the parameters are invalid. """ def __init__(self, servable_directory, servable_name, device_ids=None, version_number=0, device_type=None, num_parallel_workers=0, dec_key=None, dec_mode='AES-GCM'): super(ServableStartConfig, self).__init__() check_type.check_str("servable_directory", servable_directory) logger.info(f"input servable directory: {servable_directory}") servable_directory = get_abs_path(servable_directory) logger.info(f"absolute servable directory: {servable_directory}") check_type.check_str("servable_name", servable_name) check_type.check_int("version_number", version_number, 0) check_type.check_int("num_parallel_workers", num_parallel_workers, 0, 64) if dec_key is not None: if not isinstance(dec_key, bytes): raise RuntimeError(f"Parameter 'dec_key' should be bytes, but actually {type(dec_key)}") if not dec_key: raise RuntimeError(f"Parameter 'dec_key' should not be empty bytes") if len(dec_key) not in (16, 24, 32): raise RuntimeError(f"Parameter 'dec_key' length {len(dec_key)} expected to be 16, 24 or 32") check_type.check_str("dec_mode", dec_mode) if dec_mode not in ('AES-GCM', 'AES-CBC'): raise RuntimeError(f"Parameter 'dec_mode' expected to be 'AES-GCM' or 'AES-CBC'") self.servable_directory_ = servable_directory self.servable_name_ = servable_name self.version_number_ = version_number if device_ids is None: device_ids = [] device_ids = check_type.check_and_as_int_tuple_list("device_ids", device_ids, 0) if device_type is not None: check_type.check_str("device_type", device_type) else: device_type = "None" if device_type.lower() != "none": if device_type.lower() not in ("ascend", "gpu", "cpu"): raise RuntimeError(f"Unsupported device type '{device_type}', only support 'Ascend', 'GPU', 'CPU' " f"and None, case ignored") # else device_type is None # if device_ids is empty, and there are models declared, Cpu target should be support # if device_ids is not empty, and there are no models declared, use no device resources # if device_ids is not empty, and there are models declared, final device_type depend on inference package self.device_ids_ = device_ids if not device_ids and not num_parallel_workers: self.num_parallel_workers_ = 1 else: self.num_parallel_workers_ = num_parallel_workers self.device_type_ = device_type.lower() self.dec_key_ = dec_key self.dec_mode_ = dec_mode @property def servable_directory(self): return self.servable_directory_ @property def servable_name(self): return self.servable_name_ @property def version_number(self): return self.version_number_ @property def device_type(self): return self.device_type_ @property def device_ids(self): return self.device_ids_ @property def dec_key(self): return self.dec_key_ @property def dec_mode(self): return self.dec_mode_ @property def num_parallel_workers(self): return self.num_parallel_workers_ def _check_device_type(self, enable_lite): """Check whether the device type is valid""" device_type = self.device_type_ if device_type.lower() != "none": if device_type.lower() not in ("ascend", "gpu", "cpu"): raise RuntimeError(f"Unsupported device type '{device_type}', only support 'Ascend', 'GPU', 'CPU' " f"and None, case ignored") default_device = _get_device_type(None, enable_lite) support_cpu = _get_device_type("cpu", 
enable_lite) if support_cpu and support_cpu != default_device: support_device = f"None, '{default_device}' or '{support_cpu}'" else: support_device = f"None or '{default_device}'" if not _get_device_type(device_type, enable_lite): raise RuntimeError(f"The device type '{device_type}' of servable name {self.servable_name} " f"is inconsistent with current running environment, supported device type: " f"{support_device}") class DeployConfig: """Deployment configuration of one version for the servable""" def __init__(self, version_number, device_ids, num_parallel_workers=0, dec_key=None, dec_mode='AES-GCM'): check_type.check_int("version_number", version_number) if device_ids is None: device_ids = [] device_ids = check_type.check_and_as_int_tuple_list("device_ids", device_ids, 0) check_type.check_int("num_parallel_workers", num_parallel_workers, 0) if dec_key is not None: if not isinstance(dec_key, bytes): raise RuntimeError(f"Parameter 'dec_key' should be bytes, but actually {type(dec_key)}") if not dec_key: raise RuntimeError(f"Parameter 'dec_key' should not be empty bytes") if len(dec_key) not in (16, 24, 32): raise RuntimeError(f"Parameter 'dec_key' length {len(dec_key)} expected to be 16, 24 or 32") check_type.check_str("dec_mode", dec_mode) if dec_mode not in ('AES-GCM', 'AES-CBC'): raise RuntimeError(f"Parameter 'dec_mode' expected to be 'AES-GCM' or 'AES-CBC'") self.version_number = version_number self.device_ids = set(device_ids) if not device_ids and not num_parallel_workers: self.num_parallel_workers = 1 else: self.num_parallel_workers = num_parallel_workers self.dec_key = dec_key self.dec_mode = dec_mode class ServableStartConfigGroup: """Servable start config for one servable with multi version deployment configs""" def __init__(self, servable_directory, servable_name, device_type=None): check_type.check_str("servable_directory", servable_directory) logger.info(f"input servable directory: {servable_directory}") servable_directory = get_abs_path(servable_directory) logger.info(f"absolute servable directory: {servable_directory}") check_type.check_str("servable_name", servable_name) if device_type is not None: check_type.check_str("device_type", device_type) else: device_type = "None" self.servable_directory = servable_directory self.servable_name = servable_name self.device_type = device_type self.check_servable_location() self.deploy_configs = {} self.newest_version_number = get_newest_version_number(servable_directory, servable_name) logger.info(f"The newest version number of servable {self.servable_name} is {self.newest_version_number}, " f"servable directory: {self.servable_directory}") def check_servable_location(self): """Check the validity of parameters servable_directory and servable_name""" config_dir = os.path.join(self.servable_directory, self.servable_name) if not os.path.isdir(config_dir): raise RuntimeError( f"Check servable config failed, directory '{config_dir}' not exist, servable " f"directory '{self.servable_directory}', servable name '{self.servable_name}'") config_file = os.path.join(config_dir, "servable_config.py") if not os.path.isfile(config_file): raise RuntimeError( f"Check servable config failed, file '{config_file}' not exist, servable directory " f"'{self.servable_directory}', servable name '{self.servable_name}'") def append_deploy(self, deploy_config): """Append one deployment configuration of one version for the servable""" if not isinstance(deploy_config, DeployConfig): raise RuntimeError(f"Parameter 'deploy_config' should be type of DeployConfig") 
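        # A version number of 0 means "latest": it is resolved below to the
        # newest version number found under the servable directory.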
if deploy_config.version_number == 0: deploy_config.version_number = self.newest_version_number if deploy_config.version_number not in self.deploy_configs: self.deploy_configs[deploy_config.version_number] = deploy_config else: last_config = self.deploy_configs[deploy_config.version_number] last_config.device_ids = last_config.device_ids.union(deploy_config.device_ids) if last_config.dec_key != deploy_config.dec_key or last_config.dec_mode != deploy_config.dec_mode: raise RuntimeError(f"The dec key or dec mode of servable name {self.servable_name} is different in " f"multiple configurations.") if deploy_config.num_parallel_workers > last_config.num_parallel_workers: last_config.num_parallel_workers = deploy_config.num_parallel_workers def export_as_start_configs(self): """Export the configuration as list of ServableStartConfig""" configs = [] for config in self.deploy_configs.values(): start_config = ServableStartConfig(servable_directory=self.servable_directory, servable_name=self.servable_name, device_ids=tuple(config.device_ids), version_number=config.version_number, device_type=self.device_type, num_parallel_workers=config.num_parallel_workers, dec_key=config.dec_key, dec_mode=config.dec_mode) configs.append(start_config) return configs def _check_and_merge_config(configs): """Merge ServableStartConfig with the same version number""" start_config_groups = {} for config in configs: if not isinstance(config, ServableStartConfig): continue if config.servable_name in start_config_groups: if config.servable_directory != start_config_groups[config.servable_name].servable_directory: raise RuntimeError( f"The servable directory of servable name {config.servable_name} is different in" f" multiple configurations, servable directory: " f"{config.servable_directory} and {start_config_groups[config.servable_name].servable_directory}") else: config_group = ServableStartConfigGroup(config.servable_directory, config.servable_name, config.device_type) start_config_groups[config.servable_name] = config_group deploy_config = DeployConfig(config.version_number, config.device_ids, config.num_parallel_workers, config.dec_key, config.dec_mode) start_config_groups[config.servable_name].append_deploy(deploy_config) return start_config_groups def merge_config(configs): """Merge ServableStartConfig with the same version number""" start_config_groups = _check_and_merge_config(configs) configs_ret = [] for config_group in start_config_groups.values(): start_configs = config_group.export_as_start_configs() configs_ret.extend(start_configs) allow_reuse_device = None device_ids_used = set() for config in configs_ret: for device_id in config.device_ids: if device_id in device_ids_used: if allow_reuse_device is None: allow_reuse_device = _all_reuse_device() if not allow_reuse_device: raise RuntimeError(f"Ascend 910 device id {device_id} is used repeatedly in servable " f"{config.servable_name}") device_ids_used.add(device_id) for config in configs: if not isinstance(config, ServableStartConfig): configs_ret.append(config) return configs_ret class ServableContextData(ServableContextDataBase): """Used to startup servable process""" def __init__(self, servable_config, device_id, master_address, enable_lite): super(ServableContextData, self).__init__() self.servable_config = servable_config self.device_id = device_id self.master_address = master_address self.log_new_file = True self.enable_lite = enable_lite @property def servable_name(self): return self.servable_config.servable_name @property def version_number(self): 
return self.servable_config.version_number def to_string(self): """For logging""" return f"servable name: {self.servable_name}, device id: {self.device_id}" def new_worker_process(self): """Start worker process to provide servable""" python_exe = sys.executable config = self.servable_config device_type = config.device_type if device_type is None: device_type = "None" script_dir = os.path.dirname(os.path.abspath(__file__)) py_script = os.path.join(script_dir, "start_worker.py") if self.servable_config.dec_key: pipe_file = f"serving_temp_dec_{config.servable_name}_device{self.device_id}_" \ f"{random.randrange(1000000, 9999999)}" os.mkfifo(pipe_file) else: pipe_file = 'None' enable_lite_str = "True" if self.enable_lite else "False" arg = f"{python_exe} {py_script} " \ f"--servable_directory={config.servable_directory} " \ f"--servable_name={config.servable_name} " \ f"--version_number={config.version_number} " \ f"--device_type={device_type} " \ f"--device_id={self.device_id} " \ f"--master_address={self.master_address} " \ f"--enable_lite={enable_lite_str} " \ f"--dec_key_pipe_file={pipe_file} " \ f"--dec_mode={config.dec_mode} " \ f"--listening_master=True" args = arg.split(" ") serving_logs_dir = "serving_logs" try: os.mkdir(serving_logs_dir) except FileExistsError: pass write_mode = "w" if self.log_new_file else "a" self.log_new_file = False log_file_name = f"{serving_logs_dir}/log_{config.servable_name}_device{self.device_id}" \ f"_version{self.version_number}.log" with open(log_file_name, write_mode) as fp: sub = subprocess.Popen(args=args, shell=False, stdout=fp, stderr=fp) if self.servable_config.dec_key: with open(pipe_file, "wb") as fp: fp.write(self.servable_config.dec_key) return sub class ServableExtraContextData(ServableContextDataBase): """Used to startup servable process""" def __init__(self, servable_config, master_address, index, device_ids_empty, enable_lite): super(ServableExtraContextData, self).__init__() self.servable_config = servable_config self.master_address = master_address self.log_new_file = True self.index = index self.device_ids_empty = device_ids_empty self.enable_lite = enable_lite @property def servable_name(self): return self.servable_config.servable_name @property def version_number(self): return self.servable_config.version_number def own_device(self): """Whether the worker occupy device""" return False def to_string(self): """For logging""" return f"servable name: {self.servable_name}, version: {self.version_number}, extra: {self.index}" def new_worker_process(self): """Start worker process to provide servable""" python_exe = sys.executable config = self.servable_config script_dir = os.path.dirname(os.path.abspath(__file__)) py_script = os.path.join(script_dir, "start_extra_worker.py") if config.dec_key: pipe_file = f"serving_temp_dec_{config.servable_name}_index{self.index}_" \ f"{random.randrange(1000000, 9999999)}" os.mkfifo(pipe_file) else: pipe_file = 'None' device_type = config.device_type if device_type is None: device_type = "None" enable_lite_str = "True" if self.enable_lite else "False" arg = f"{python_exe} {py_script} " \ f"--servable_directory={config.servable_directory} " \ f"--servable_name={config.servable_name} " \ f"--version_number={config.version_number} " \ f"--device_type={device_type} " \ f"--device_ids_empty={self.device_ids_empty} " \ f"--index={self.index} " \ f"--enable_lite={enable_lite_str} " \ f"--master_address={self.master_address} " \ f"--dec_key_pipe_file={pipe_file} " \ f"--dec_mode={config.dec_mode} " \ 
f"--listening_master=True" args = arg.split(" ") serving_logs_dir = "serving_logs" try: os.mkdir(serving_logs_dir) except FileExistsError: pass write_mode = "w" if self.log_new_file else "a" self.log_new_file = False log_file_name = f"{serving_logs_dir}/log_{config.servable_name}_extra{self.index}" \ f"_version{self.version_number}.log" with open(log_file_name, write_mode) as fp: sub = subprocess.Popen(args=args, shell=False, stdout=fp, stderr=fp) if self.servable_config.dec_key: with open(pipe_file, "wb") as fp: fp.write(self.servable_config.dec_key) return sub ================================================ FILE: mindspore_serving/server/_server.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Interface for start up servable""" import os import time import threading import signal import mindspore_serving.log as logger from mindspore_serving.server.worker.init_mindspore import set_mindspore_cxx_env from mindspore_serving.server.master import start_master_server, stop_on_except, stop, at_stop_list, only_model_stage from mindspore_serving.server._servable_common import WorkerContext from mindspore_serving.server._servable_local import ServableStartConfig, ServableContextData, merge_config from mindspore_serving.server._servable_local import ServableExtraContextData from mindspore_serving.server.distributed._servable_distributed import DistributedStartConfig, DistributedContextData from mindspore_serving.server.common import check_type from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import ServableContext_ @stop_on_except def start_servables(servable_configs, enable_lite=False): r""" Used to start one or more servables on the serving server. One model can be combined with preprocessing and postprocessing to provide a servable, and multiple models can also be combined to provide a servable. This interface can be used to start multiple different servables. One servable can be deployed on multiple devices, and each device runs a servable copy. On Ascend 910 hardware platform, each copy of each servable owns one device. Different servables or different versions of the same servable need to be deployed on different devices. On Ascend 310P and GPU hardware platform, one device can be shared by multi servables, and different servables or different versions of the same servable can be deployed on the same chip to realize device reuse. For details about how to configure models to provide servables, please refer to `MindSpore-based Inference Service Deployment `_ and `Servable Provided Through Model Configuration `_. Args: servable_configs (Union[ServableStartConfig, list[ServableStartConfig], tuple[ServableStartConfig]]): The startup configs of one or more servables. enable_lite (bool): Whether to use MindSpore Lite inference backend. Default False. 
Raises: RuntimeError: Failed to start one or more servables. For log of one servable, please refer to subdirectory serving_logs of the directory where the startup script is located. Examples: >>> import os >>> from mindspore_serving import server >>> >>> servable_dir = os.path.abspath(".") >>> resnet_config = server.ServableStartConfig(servable_dir, "resnet", device_ids=(0,1)) >>> add_config = server.ServableStartConfig(servable_dir, "add", device_ids=(2,3)) >>> server.start_servables(servable_configs=(resnet_config, add_config)) # press Ctrl+C to stop >>> server.start_grpc_server("0.0.0.0:5500") """ if isinstance(servable_configs, (ServableStartConfig, DistributedStartConfig)): servable_configs = (servable_configs,) if not isinstance(servable_configs, (tuple, list)): raise RuntimeError(f"Parameter '{servable_configs}' should be ServableStartConfig, list or tuple of " f"ServableStartConfig, but actually {type(servable_configs)}") check_type.check_bool("enable_lite", enable_lite) for config in servable_configs: if not isinstance(config, (ServableStartConfig, DistributedStartConfig)): raise RuntimeError( f"The item of parameter '{servable_configs}' should be ServableStartConfig, but actually " f"{type(config)}") if isinstance(config, ServableStartConfig): # pylint: disable=protected-access config._check_device_type(enable_lite) ServableContext_.get_instance().set_enable_lite(enable_lite) set_mindspore_cxx_env() # merge ServableStartConfig with same servable name and running version number try: servable_configs = merge_config(servable_configs) except RuntimeError as e: logger.error(f"Start servables failed: {str(e)}") raise logger.info("Servable configs:") for config in servable_configs: if isinstance(config, ServableStartConfig): logger.info( f"servable directory: {config.servable_directory}, servable name: {config.servable_name}, " f"running version number: {config.version_number}, device ids:{config.device_ids}, " f"device type: {config.device_type}") if isinstance(config, DistributedStartConfig): logger.info(f"distributed servable, servable directory: {config.servable_directory}, " f"servable name: {config.servable_name}, rank table json file: {config.rank_table_json_file}, " f"running version number: {config.version_number}, " f"distributed address:{config.distributed_address}, " f"wait agents time: {config.wait_agents_time_in_seconds}s") master_pid = os.getpid() unix_socket_dir = "unix_socket_files" try: os.mkdir(unix_socket_dir) except FileExistsError: pass master_address = f"unix:{unix_socket_dir}/serving_master_{master_pid}" start_master_server(address=master_address) signal.signal(signal.SIGCHLD, signal.SIG_IGN) worker_list = _start_workers_with_devices(master_address, servable_configs, enable_lite) has_device_workers = bool(worker_list) _listening_workers_when_startup(worker_list) extra_worker_list = _start_extra_workers(master_address, servable_configs, enable_lite) worker_list.extend(extra_worker_list) _listening_workers_after_startup(worker_list, has_device_workers) def _start_workers_with_devices(master_address, servable_configs, enable_lite): """Start workers that occupy devices""" worker_list = [] for config in servable_configs: if isinstance(config, ServableStartConfig): for device_id in config.device_ids: try: context_data = ServableContextData(config, device_id, master_address, enable_lite) sub_process = context_data.new_worker_process() worker_context = WorkerContext(context_data, master_address, sub_process) except RuntimeError as e: 
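                    # On a failed launch, send exit signals to the workers already
                    # started, so no orphaned device-holding processes remain.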
_send_exit_signal_to_children(worker_list) raise RuntimeError(f"Start worker failed: {e}") worker_list.append(worker_context) elif isinstance(config, DistributedStartConfig): try: context_data = DistributedContextData(config, master_address) sub_process = context_data.new_worker_process() worker_context = WorkerContext(context_data, master_address, sub_process) except RuntimeError as e: _send_exit_signal_to_children(worker_list) raise RuntimeError(f"Start worker failed: {e}") worker_list.append(worker_context) return worker_list def _start_extra_workers(master_address, servable_configs, enable_lite): """Start workers that do not occupy devices""" worker_list = [] worker_pid_set = set() for config in servable_configs: if not isinstance(config, ServableStartConfig): continue if len(config.device_ids) >= config.num_parallel_workers: continue if only_model_stage(config.servable_name): logger.warning(f"There is no need to startup additional worker processes, all stages are models, servable:" f" {config.servable_name}") continue extra_worker_count = config.num_parallel_workers - len(config.device_ids) for index in range(extra_worker_count): try: context_data = ServableExtraContextData(config, master_address, index, not config.device_ids, enable_lite) sub_process = context_data.new_worker_process() if sub_process.pid in worker_pid_set: raise RuntimeError( f"Maybe the parameter 'num_parallel_workers' is too large, and the number of open files exceeds" f" the system upper limit. Please check the workers logs in the serving_logs directory for" f" more details") worker_pid_set.add(sub_process.pid) worker_context = WorkerContext(context_data, master_address, sub_process) except RuntimeError as e: _send_exit_signal_to_children(worker_list) raise RuntimeError(f"Start worker failed: {e}") worker_list.append(worker_context) _listening_workers_when_startup(worker_list) return worker_list def _send_exit_signal_to_children(worker_list): """Send exit signal to all child processes, and terminate all child processes when they are still alive in some seconds later. 
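
    Workers are first sent SIGINT; any worker still alive about 10 seconds
    later is forcibly killed with SIGKILL.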
""" if not worker_list: return for worker in worker_list: worker.send_exit_signal(signal.SIGINT) wait_seconds = 10 for i in range(wait_seconds * 100): # 10s all_exit = True for worker in worker_list: if worker.is_alive(): if i % 100 == 0: logger.warning(f"Wait for all worker processes to exit, otherwise they will be forcibly killed in " f"{wait_seconds - (i // 100)} seconds.") all_exit = False break if all_exit: logger.info(f"All Child process exited") return time.sleep(0.01) for worker in worker_list: worker.send_exit_signal(signal.SIGKILL) def _listening_workers_when_startup(worker_list): """Listening child process""" if not worker_list: return time_last = time.time() while True: time.sleep(0.1) if ExitSignalHandle_.has_stopped(): logger.warning("Fail to start workers because of signal SIGINT or SIGTERM") _send_exit_signal_to_children(worker_list) raise RuntimeError("Fail to start workers because of signal SIGINT or SIGTERM") all_ready = True for worker in worker_list: if not worker.is_alive() or worker.has_error_notified(): for _ in range(100): if worker.has_error_notified(): logger.warning(f"Fail to start workers: {worker.get_notified_error()}") _send_exit_signal_to_children(worker_list) raise RuntimeError(f"Fail to start workers: {worker.get_notified_error()}") time.sleep(0.01) # wait 1s for error msg logger.error(f"Fail to start workers because of death of one worker") _send_exit_signal_to_children(worker_list) raise RuntimeError("Fail to start workers because of death of one worker") if not worker.ready(): if time.time() - time_last > 1: time_last = time.time() worker.print_status() all_ready = False if all_ready: break logger.info("All workers is ready") def _listening_workers_after_startup(worker_list, has_device_workers): """Listening agent status after success start up of agents""" def listening_thread_fun(): while True: time.sleep(0.01) if ExitSignalHandle_.has_stopped(): logger.warning("Serving server begin to exit: receive exit signal") break alive_count = 0 for worker in worker_list: occupy_device_worker = 1 if worker.own_device() or not has_device_workers else 0 if worker.is_in_process_switching: alive_count += occupy_device_worker continue if worker.is_alive(): alive_count += occupy_device_worker if worker.is_unavailable(): worker.restart_worker() continue # not alive # has exit or error notified, if worker.has_exit_notified() or worker.has_error_notified(): continue if worker.exit_for_enough_time(): # has exit for 1s and there were no normal handled requests if not worker.can_be_restart(): continue logger.warning( f"detect worker process has exited, try to restart, servable: {worker.to_string()}") worker.restart_worker() alive_count += occupy_device_worker if not alive_count: logger.warning("Serving server begin to exit: all worker processes that occupy devices have exited") break _send_exit_signal_to_children(worker_list) stop() thread = threading.Thread(target=listening_thread_fun) thread.start() def join_thread(): if thread != threading.current_thread(): thread.join() return True return False at_stop_list.append(join_thread) ================================================ FILE: mindspore_serving/server/common/__init__.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """MindSpore Serving.""" from . import check_type from .utils import get_abs_path from .decorator import deprecated ================================================ FILE: mindspore_serving/server/common/check_type.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Type check for worker""" def check_and_as_tuple_with_str_list(arg_name, strs): """Check whether the input parameters are reasonable multiple str inputs, which can be single str, tuple or list of str, tuple with list of str. Finally, return tuple with list of str. """ if isinstance(strs, str): return ([strs],) if not isinstance(strs, (tuple, list)): raise RuntimeError(f"Parameter '{arg_name}' should be str or tuple/list of str, but actually {type(strs)}") str_list = [] for item in strs: it_list = [] if isinstance(item, list): for inner in item: if not isinstance(inner, str): raise RuntimeError(f"The inner of parameter '{arg_name}' should be str, " f"but actually {type(inner)}") if not inner: raise RuntimeError(f"The inner of parameter '{arg_name}' should not be empty str") if inner in it_list: raise RuntimeError(f"The inner value '{inner}' in parameter '{arg_name}' " f"should not be repeated") it_list.append(inner) else: if not isinstance(item, str): raise RuntimeError(f"The item of parameter '{arg_name}' should be str, but actually {type(item)}") if not item: raise RuntimeError(f"The item of parameter '{arg_name}' should not be empty str") if [item] in str_list: raise RuntimeError(f"The item value '{item}' in parameter '{arg_name}' should not be repeated") it_list.append(item) str_list.append(it_list) return tuple(str_list) def check_and_as_str_tuple_list(arg_name, strs): """Check whether the input parameters are reasonable multiple str inputs, which can be single str, tuple or list of str. Finally, return tuple of str. 
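    For example (a doctest-style sketch of the normalization this performs, based on the
    checks below):

    >>> check_and_as_str_tuple_list('output_names', 'y')
    ('y',)
    >>> check_and_as_str_tuple_list('output_names', ['y1', 'y2'])
    ('y1', 'y2')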
""" if isinstance(strs, str): strs = (strs,) if not isinstance(strs, (tuple, list)): raise RuntimeError(f"Parameter '{arg_name}' should be str or tuple/list of str, but actually {type(strs)}") str_list = [] for item in strs: if not isinstance(item, str): raise RuntimeError(f"The item of parameter '{arg_name}' should be str, but actually {type(item)}") if not item: raise RuntimeError(f"The item of parameter '{arg_name}' should not be empty str") if item in str_list: raise RuntimeError(f"The item value '{item}' in parameter '{arg_name}' should not be repeated") str_list.append(item) return tuple(str_list) def check_str(arg_name, str_val): """Check whether the input parameters are reasonable str input""" if not isinstance(str_val, str): raise RuntimeError(f"Parameter '{arg_name}' should be str, but actually {type(str_val)}") if not str_val: raise RuntimeError(f"Parameter '{arg_name}' should not be empty str") def check_bytes(arg_name, bytes_val): """Check whether the input parameters are reasonable bytes input""" if not isinstance(bytes_val, bytes): raise RuntimeError(f"Parameter '{arg_name}' should be bytes, but actually {type(bytes_val)}") if not bytes_val: raise RuntimeError(f"Parameter '{arg_name}' should not be empty bytes") def check_bool(arg_name, bool_val): """Check whether the input parameters are reasonable bool input""" if not isinstance(bool_val, bool): raise RuntimeError(f"Parameter '{arg_name}' should be bool, but actually {type(bool_val)}") def check_int(arg_name, int_val, minimum=None, maximum=None, is_tuple_item=False): """Check whether the input parameters are reasonable int input""" if not is_tuple_item: prefix = f"Parameter '{arg_name}'" else: prefix = f"The item value '{int_val}' in parameter '{arg_name}'" if isinstance(int_val, bool): raise RuntimeError(f"{prefix} should be int, but actually {type(int_val)}") if not isinstance(int_val, int): raise RuntimeError(f"{prefix} should be int, but actually {type(int_val)}") if minimum is not None and int_val < minimum: if maximum is not None: raise RuntimeError(f"{prefix} should be in range [{minimum},{maximum}]") raise RuntimeError(f"{prefix} should be >= {minimum}") if maximum is not None and int_val > maximum: if minimum is not None: raise RuntimeError(f"{prefix} should be in range [{minimum},{maximum}]") raise RuntimeError(f"{prefix} should be <= {maximum}") def check_ip_port(arg_name, port): """Check whether the input parameters are reasonable ip port""" check_int(arg_name, port, 1, 65535) def check_and_as_int_tuple_list(arg_name, ints, minimum=None, maximum=None): """Check whether the input parameters are reasonable multiple int inputs, which can be single int, tuple or list of int. Finally, return tuple of int. """ if isinstance(ints, int): ints = (ints,) if not isinstance(ints, (tuple, list)): raise RuntimeError(f"Parameter '{arg_name}' should be int or tuple/list of int, but actually {type(ints)}") int_list = [] for item in ints: if item in int_list: raise RuntimeError(f"The item value '{item}' in parameter '{arg_name}' should not be repeated") check_int(arg_name, item, minimum, maximum, True) int_list.append(item) return tuple(int_list) def check_int_tuple_list(arg_name, ints, minimum=None, maximum=None): """Check whether the input parameters are reasonable multiple int inputs, which can be single tuple or list of int. Finally, return tuple of int. 
""" if not isinstance(ints, (tuple, list)): raise RuntimeError(f"Parameter '{arg_name}' should be tuple/list of int, but actually {type(ints)}") int_list = [] for item in ints: if item in int_list: raise RuntimeError(f"The item value '{item}' in parameter '{arg_name}' should not be repeated") check_int(arg_name, item, minimum, maximum, True) int_list.append(item) ================================================ FILE: mindspore_serving/server/common/decorator.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Providing decorators.""" from functools import wraps from mindspore_serving import log def deprecated(version, substitute): """deprecated warning Args: version (str): version that the operator or function is deprecated. substitute (str): the substitute name for deprecated operator or function. """ def decorate(func): @wraps(func) def wrapper(*args, **kwargs): name = func.__name__ log.warning(f"'{name}' is deprecated from version {version} and " f"will be removed in a future version, use '{substitute}' instead.") ret = func(*args, **kwargs) return ret return wrapper return decorate ================================================ FILE: mindspore_serving/server/common/utils.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """common function utils""" import os import sys def get_abs_path(path): """get the absolute path""" script_dir = os.path.dirname(os.path.realpath(sys.argv[0])) abs_path = os.path.realpath(os.path.join(script_dir, path)) return abs_path ================================================ FILE: mindspore_serving/server/distributed/__init__.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """The interface to start up the serving server with a distributed servable. To see how to configure and start up a distributed model, please refer to `MindSpore Serving-based Distributed Inference Service Deployment `_.""" from mindspore_serving.server.worker.distributed import startup_agents from mindspore_serving.server.worker.distributed.register import declare_servable from ._distributed import start_servable __all__ = [] __all__.extend([ "start_servable", 'startup_agents', 'declare_servable' ]) ================================================ FILE: mindspore_serving/server/distributed/_distributed.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Start up the serving server with a distributed servable""" from ._servable_distributed import DistributedStartConfig def start_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, distributed_address="0.0.0.0:6200", wait_agents_time_in_seconds=0): r""" Start up the servable named 'servable_name' defined in 'servable_directory'. Args: servable_directory (str): The directory where the servable is located in. There is expected to be a directory named `servable_name`. For more detail: `How to config Servable `_ . servable_name (str): The servable name. version_number (int, optional): Servable version number to be loaded. The version number should be a positive integer, starting from 1. Default: 1. rank_table_json_file (str): The rank table json file name. distributed_address (str, optional): The distributed worker address that the worker agents link to. Default: "0.0.0.0:6200". wait_agents_time_in_seconds(int, optional): The maximum time in seconds that the worker waits for all agents to be ready, 0 means unlimited time. Default: 0. Raises: RuntimeError: Failed to start the distributed servable. Examples: >>> import os >>> from mindspore_serving.server import distributed >>> >>> servable_dir = os.path.abspath(".") >>> distributed.start_servable(servable_dir, "matmul", rank_table_json_file="hccl_8p.json", \ ... 
distributed_address="127.0.0.1:6200") """ from mindspore_serving.server import start_servables config = DistributedStartConfig(servable_directory=servable_directory, servable_name=servable_name, rank_table_json_file=rank_table_json_file, version_number=version_number, distributed_address=distributed_address, wait_agents_time_in_seconds=wait_agents_time_in_seconds) start_servables(config) ================================================ FILE: mindspore_serving/server/distributed/_servable_distributed.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Distributed servable config""" import os import sys import subprocess from mindspore_serving.server.common import check_type, get_abs_path import mindspore_serving.log as logger from mindspore_serving.server._servable_common import ServableContextDataBase class DistributedStartConfig: r""" Distributed servable start-up config. Args: servable_directory (str): The directory where the servable is located in. There expects to has a directory named `servable_name`. For more detail: `How to config Servable `_ . servable_name (str): The servable name. rank_table_json_file (str): The rank table json file name. version_number (int): Servable version number to be loaded. The version number should be a positive integer, starting from 1, and 0 means to load the latest version. Default: 0. distributed_address (str): The worker address the agents linked to. wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents, 0 means unlimited time, default 0 Raises: RuntimeError: Input parameters are invalid. 
""" def __init__(self, servable_directory, servable_name, rank_table_json_file, version_number, distributed_address, wait_agents_time_in_seconds): super(DistributedStartConfig, self).__init__() check_type.check_str('servable_directory', servable_directory) logger.info(f"input servable directory: {servable_directory}") servable_directory = get_abs_path(servable_directory) logger.info(f"absolute servable directory: {servable_directory}") check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 0) if version_number == 0: version_number = 1 check_type.check_str('rank_table_json_file', rank_table_json_file) logger.info(f"input rank table file: {rank_table_json_file}") rank_table_json_file = get_abs_path(rank_table_json_file) logger.info(f"absolute path of rank table file: {rank_table_json_file}") check_type.check_str('distributed_address', distributed_address) check_type.check_int('wait_agents_time_in_seconds', wait_agents_time_in_seconds, 0) self.servable_directory_ = servable_directory self.servable_name_ = servable_name self.version_number_ = version_number self.rank_table_json_file_ = rank_table_json_file self.distributed_address_ = distributed_address self.wait_agents_time_in_seconds_ = wait_agents_time_in_seconds @property def servable_directory(self): return self.servable_directory_ @property def servable_name(self): return self.servable_name_ @property def version_number(self): return self.version_number_ @property def rank_table_json_file(self): return self.rank_table_json_file_ @property def distributed_address(self): return self.distributed_address_ @property def wait_agents_time_in_seconds(self): return self.wait_agents_time_in_seconds_ class DistributedContextData(ServableContextDataBase): """Used to start distributed servable worker process""" def __init__(self, distributed_config, master_address): super(DistributedContextData, self).__init__() if not isinstance(distributed_config, DistributedStartConfig): raise RuntimeError(f"Parameter '{distributed_config}' should be instance of DistributedStartConfig, " f"but actually {type(distributed_config)}") self.distributed_config_ = distributed_config self.master_address_ = master_address self.log_new_file = True @property def servable_name(self): return self.distributed_config_.servable_name @property def version_number(self): return self.distributed_config_.version_number def to_string(self): """Used in logging""" return f"distributed servable name: {self.servable_name}" def new_worker_process(self): """Start distributed worker process""" python_exe = sys.executable script_dir = os.path.dirname(os.path.abspath(__file__)) py_script = os.path.join(script_dir, "start_distributed_worker.py") config = self.distributed_config_ arg = f"{python_exe} {py_script} {config.servable_directory} {config.servable_name} " \ f"{config.version_number} {config.rank_table_json_file} {config.distributed_address} " \ f"{config.wait_agents_time_in_seconds} {self.master_address_} True" args = arg.split(" ") serving_logs_dir = "serving_logs" try: os.mkdir(serving_logs_dir) except FileExistsError: pass write_mode = "w" if self.log_new_file else "a" self.log_new_file = False log_file_name = f"{serving_logs_dir}/log_{self.servable_name}_distributed.log" with open(log_file_name, write_mode) as fp: sub = subprocess.Popen(args=args, shell=False, stdout=fp, stderr=fp) return sub def can_restart(self): """Whether the worker can restart""" return False ================================================ FILE: 
mindspore_serving/server/distributed/start_distributed_worker.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Start distributed worker process""" import os import sys from mindspore_serving.server.worker import distributed from mindspore_serving.server.common import check_type from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Worker_ def start_worker(servable_directory, servable_name, version_number, rank_table_json_file, distributed_address, wait_agents_time_in_seconds, master_address, listening_master=False): """Start distributed worker process""" check_type.check_str('servable_directory', servable_directory) check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 0) check_type.check_str('rank_table_json_file', rank_table_json_file) check_type.check_str('distributed_address', distributed_address) check_type.check_int('wait_agents_time_in_seconds', wait_agents_time_in_seconds, 0) check_type.check_str('master_address', master_address) check_type.check_bool('listening_master', listening_master) ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message worker_pid = os.getpid() unix_socket_dir = "unix_socket_files" try: os.mkdir(unix_socket_dir) except FileExistsError: pass worker_address = f"unix:{unix_socket_dir}/serving_worker_{servable_name}_distributed_{worker_pid}" if len(worker_address) > 90: # limit maximum unix domain socket address length worker_address = worker_address[:40] + "___" + worker_address[-40:] try: distributed.start_servable(servable_directory=servable_directory, servable_name=servable_name, version_number=version_number, rank_table_json_file=rank_table_json_file, distributed_address=distributed_address, wait_agents_time_in_seconds=wait_agents_time_in_seconds, master_address=master_address, worker_address=worker_address) except RuntimeError as ex: Worker_.notify_failed(master_address, f"{{distributed servable:{servable_name}, {ex}}}") raise def parse_args_and_start(): """Parse args and start distributed worker""" if len(sys.argv) != 9: raise RuntimeError("Expect 8 input arguments: str{servable_directory} str{servable_name} " "int{version_number} str{rank_table_json_file} str{distributed_address} " "int{wait_agents_time_in_seconds} str{master_address} bool{listening_master}") servable_directory = sys.argv[1] servable_name = sys.argv[2] version_number = int(sys.argv[3]) rank_table_json_file = sys.argv[4] distributed_address = sys.argv[5] wait_agents_time_in_seconds = int(sys.argv[6]) master_address = sys.argv[7] # pylint: disable=simplifiable-if-expression listening_master = True if sys.argv[8].lower() == "true" else False start_worker(servable_directory, servable_name, version_number, rank_table_json_file, distributed_address, wait_agents_time_in_seconds, master_address, listening_master) if __name__ == '__main__': parse_args_and_start()
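# Illustration (not part of the original file): DistributedContextData.new_worker_process
# in _servable_distributed.py launches this script with positional arguments, roughly
# (all paths here are hypothetical):
#   python start_distributed_worker.py /path/to/servables matmul 1 rank_table_8pcs.json \
#       127.0.0.1:6200 0 127.0.0.1:6100 True
# which parse_args_and_start() above maps onto start_worker(...).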
================================================ FILE: mindspore_serving/server/master/__init__.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """The master process of serving server: used to receive requests and dispatch them to worker processes""" from ._master import start_grpc_server, start_restful_server, stop, stop_on_except, SSLConfig from ._master import start_master_server, at_stop_list, only_model_stage from . import context ================================================ FILE: mindspore_serving/server/master/_master.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Methods of the server supplied for the master""" from functools import wraps from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Master_ from mindspore_serving._mindspore_serving import SSLConfig_ from mindspore_serving.server.common import check_type _wait_and_clear_thread = None at_stop_list = [] def add_atstop_proc(func): """At serving server stop, execute function""" global at_stop_list at_stop_list.append(func) def stop(): r""" Stop the running of serving server. Examples: >>> from mindspore_serving import server >>> >>> server.start_grpc_server("0.0.0.0:5500") >>> server.start_restful_server("0.0.0.0:1500") >>> ... >>> server.stop() """ Master_.stop_and_clear() global at_stop_list for func in tuple(at_stop_list): # iterate over a copy, since items may be removed result = func() if result is None or result is True: at_stop_list.remove(func) def stop_on_except(func): """Wrapper to clear the environment and exit on Serving exception""" @wraps(func) def handle_except(*args, **kwargs): try: ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message func(*args, **kwargs) except: stop() raise return handle_except class SSLConfig: r""" The server's ssl_config encapsulates necessary parameters for SSL-enabled connections. Args: certificate (str): File holding the PEM-encoded certificate chain as a byte string to use or None if no certificate chain should be used. private_key (str): File holding the PEM-encoded private key as a byte string, or None if no private key should be used. 
custom_ca (str, optional): File holding the PEM-encoded root certificates as a byte string. When verify_client is True, custom_ca must be provided. When verify_client is False, this parameter will be ignored. Default: None. verify_client (bool, optional): If verify_client is true, use mutual authentication. If false, use one-way authentication. Default: False. Raises: RuntimeError: The type or value of the parameters are invalid. """ def __init__(self, certificate, private_key, custom_ca=None, verify_client=False): check_type.check_str("certificate", certificate) check_type.check_str("private_key", private_key) check_type.check_bool("verify_client", verify_client) self.custom_ca = custom_ca self.certificate = certificate self.private_key = private_key self.verify_client = verify_client if self.verify_client: check_type.check_str("custom_ca", custom_ca) @stop_on_except def start_grpc_server(address, max_msg_mb_size=100, ssl_config=None): r""" Start gRPC server for the communication between serving client and server. Args: address (str): gRPC server address, the address can be `{ip}:{port}` or `unix:{unix_domain_file_path}`. - `{ip}:{port}` - Internet domain socket address. - `unix:{unix_domain_file_path}` - Unix domain socket address, which is used to communicate with multiple processes on the same machine. `{unix_domain_file_path}` can be relative or absolute file path, but the directory where the file is located must already exist. max_msg_mb_size (int, optional): The maximum acceptable gRPC message size in megabytes(MB), value range [1, 512]. Default: 100. ssl_config (mindspore_serving.server.SSLConfig, optional): The server's ssl_config, if None, disabled ssl. Default: None. Raises: RuntimeError: Failed to start the gRPC server: parameter verification failed, the gRPC address is wrong or the port is duplicate. Examples: >>> from mindspore_serving import server >>> >>> server.start_grpc_server("0.0.0.0:5500") """ check_type.check_str('address', address) check_type.check_int('max_msg_mb_size', max_msg_mb_size, 1, 512) config = SSLConfig_() if ssl_config is not None: if not isinstance(ssl_config, SSLConfig): raise RuntimeError("The type of ssl_config should be type of SSLConfig") with open(ssl_config.certificate, 'rb') as c_fs: c_bytes = c_fs.read() with open(ssl_config.private_key, 'rb') as pk_fs: pk_bytes = pk_fs.read() if ssl_config.verify_client: with open(ssl_config.custom_ca, 'rb') as rc_fs: rc_bytes = rc_fs.read() config.custom_ca = rc_bytes config.certificate = c_bytes config.private_key = pk_bytes config.verify_client = ssl_config.verify_client config.use_ssl = True Master_.start_grpc_server(address, config, max_msg_mb_size) @stop_on_except def start_restful_server(address, max_msg_mb_size=100, ssl_config=None): r""" Start RESTful server for the communication between serving client and server. Args: address (str): RESTful server address, the address should be Internet domain socket address. max_msg_mb_size (int, optional): The maximum acceptable RESTful message size in megabytes(MB), value range [1, 512]. Default: 100. ssl_config (mindspore_serving.server.SSLConfig, optional): The server's ssl_config, if None, disabled ssl. Default: None. Raises: RuntimeError: Failed to start the RESTful server: parameter verification failed, the RESTful address is wrong or the port is duplicate. 
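    An SSL-enabled variant would look roughly like this (a sketch, assuming `server.crt`
    and `server.key` exist in the working directory):

    >>> from mindspore_serving import server
    >>> ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key")
    >>> server.start_restful_server("0.0.0.0:5900", ssl_config=ssl_config)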
Examples: >>> from mindspore_serving import server >>> >>> server.start_restful_server("0.0.0.0:5900") """ check_type.check_str('address', address) check_type.check_int('max_msg_mb_size', max_msg_mb_size, 1, 512) config = SSLConfig_() if ssl_config is not None: if not isinstance(ssl_config, SSLConfig): raise RuntimeError("The type of ssl_config should be SSLConfig") if ssl_config.verify_client: config.custom_ca = ssl_config.custom_ca config.certificate = ssl_config.certificate config.private_key = ssl_config.private_key config.verify_client = ssl_config.verify_client config.use_ssl = True Master_.start_restful_server(address, config, max_msg_mb_size) def start_master_server(address): """Start the gRPC server for the communication between workers and the master of serving server""" check_type.check_str('address', address) Master_.start_grpc_master_server(address) def only_model_stage(servable_name): """Whether only the model stages exist""" return Master_.only_model_stage(servable_name) ================================================ FILE: mindspore_serving/server/master/context.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Set context of serving""" from mindspore_serving._mindspore_serving import MasterContext_ from mindspore_serving.server.common import check_type __all__ = ["set_max_enqueued_requests"] _context = MasterContext_.get_instance() def set_max_enqueued_requests(max_enqueued_requests): r""" Set the maximum number of requests waiting to be processed. Args: max_enqueued_requests (int): The maximum number of requests allowed to wait in the queue, which should be a positive integer. Default: ``10000``. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. """ check_type.check_int("max_enqueued_requests", max_enqueued_requests, 1) _context.set_max_enqueued_requests(max_enqueued_requests) ================================================ FILE: mindspore_serving/server/register/__init__.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Servable register interface, used in servable_config.py of one servable. 
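A minimal servable_config.py built from these interfaces looks like the following
(a sketch, assuming a `tensor_add.mindir` model file exists in the servable directory):

>>> from mindspore_serving.server import register
>>> model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR")
>>>
>>> @register.register_method(output_names=["y"])
>>> def add_common(x1, x2):
...     y = register.add_stage(model, x1, x2, outputs_count=1)
...     return y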
See how to configure servable_config.py file, please refer to `Servable Provided Through Model Configuration `_.""" from .model import declare_model, Model, Context, AclOptions, GpuOptions from .model import AscendDeviceInfo, CPUDeviceInfo, GPUDeviceInfo from .method import register_method, add_stage from .model import declare_servable from .method import call_preprocess, call_servable, call_postprocess from .method import call_preprocess_pipeline, call_postprocess_pipeline __all__ = [] __all__.extend([ "declare_model", "Model", "AscendDeviceInfo", "CPUDeviceInfo", "GPUDeviceInfo", "Context", 'register_method', 'add_stage' ]) ================================================ FILE: mindspore_serving/server/register/method.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Method registration interface""" import inspect import ast from functools import wraps from mindspore_serving._mindspore_serving import ServableRegister_ from mindspore_serving._mindspore_serving import MethodSignature_ from mindspore_serving import log as logger from mindspore_serving.server.common import check_type, deprecated from .utils import get_func_name, get_servable_dir from .stage_function import register_stage_function, check_stage_function from .model import g_declared_models, Model method_def_context_ = MethodSignature_() cur_stage_index_ = 0 has_called_preprocess_ = False has_called_servable_ = False has_called_postprocess_ = False method_def_ast_meta_ = [] class _TensorDef: """Data flow item, for definitions of data flow in a method""" def __init__(self, tag, tensor_index): self.tag = tag self.tensor_index = tensor_index def as_pair(self): return self.tag, self.tensor_index def _create_tensor_def_outputs(tag, outputs_cnt): """Create data flow item for output""" result = [_TensorDef(tag, i) for i in range(outputs_cnt)] if len(result) == 1: return result[0] return tuple(result) def _wrap_fun_to_batch(fun, input_count): """wrap preprocess and postprocess to pipeline""" argspec_len = len(inspect.signature(fun).parameters) if argspec_len != input_count: raise RuntimeError(f"function {fun.__name__} input args count {argspec_len} not match the count {input_count} " f"registered in method") @wraps(fun) def call_func(instances): for instance in instances: inputs = [] for i in range(input_count): inputs.append(instance[i]) yield fun(*inputs) return call_func def _get_stage_outputs_count(call_name): global method_def_ast_meta_ method_name = method_def_context_.method_name if call_name not in method_def_ast_meta_: raise RuntimeError( f"Failed to parse method '{method_name}', complex statements such as conditions and loops are not supported" f" in register_method when the interface '{call_name}' is used, use 'add_stage' to replace '{call_name}'") _, outputs_count = method_def_ast_meta_[call_name] return outputs_count @deprecated("1.5.0", 
"mindspore_serving.server.register.add_stage") def call_preprocess(preprocess_fun, *args): r"""For method registration, define the preprocessing function and its' parameters. .. warning:: 'call_preprocess' is deprecated from version 1.5.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.add_stage` instead. Note: The length of 'args' should be equal to the inputs number of preprocess_fun. Args: preprocess_fun (function): Python function for preprocess. args: Preprocess inputs. The length of 'args' should equal to the input parameters number of implemented python function. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. Examples: >>> from mindspore_serving.server import register >>> import numpy as np >>> def add_trans_datatype(x1, x2): ... return x1.astype(np.float32), x2.astype(np.float32) >>> >>> register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) >>> >>> @register.register_method(output_names=["y"]) # register add_cast method in add >>> def add_cast(x1, x2): ... x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) # cast input to float32 ... y = register.call_servable(x1, x2) ... return y """ global method_def_context_ global has_called_preprocess_, has_called_servable_, has_called_postprocess_ if has_called_preprocess_: raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', " f"call_preprocess or call_preprocess_pipeline should not be invoked more than once") if has_called_servable_: raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', " f"call_servable should be invoked after call_preprocess") if has_called_postprocess_: raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', " f"call_postprocess or call_postprocess_pipeline should be invoked after call_preprocess") has_called_preprocess_ = True outputs_count = _get_stage_outputs_count('call_preprocess') return add_stage(preprocess_fun, *args, outputs_count=outputs_count, tag="Preprocess") @deprecated("1.5.0", "mindspore_serving.server.register.add_stage") def call_preprocess_pipeline(preprocess_fun, *args): r"""For method registration, define the preprocessing pipeline function and its' parameters. .. warning:: 'call_preprocess_pipeline' is deprecated from version 1.5.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.add_stage` instead. A single request can include multiple instances, so multiple queued requests will also have multiple instances. If you need to process multiple instances through multi thread or other parallel processing capability in `preprocess` or `postprocess`, such as using MindData concurrency ability to process multiple input images in `preprocess`, MindSpore Serving provides 'call_preprocess_pipeline' and 'call_postprocess_pipeline' to register such preprocessing and postprocessing. For more detail, please refer to `Resnet50 model configuration example `_. Args: preprocess_fun (function): Python pipeline function for preprocess. args: Preprocess inputs. The length of 'args' should equal to the input parameters number of implemented python function. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. Examples: >>> from mindspore_serving.server import register >>> import numpy as np >>> def add_trans_datatype(instances): ... for instance in instances: ... x1 = instance[0] ... x2 = instance[0] ... 
@deprecated("1.5.0", "mindspore_serving.server.register.add_stage") def call_preprocess_pipeline(preprocess_fun, *args): r"""For method registration, define the preprocessing pipeline function and its parameters. .. warning:: 'call_preprocess_pipeline' is deprecated from version 1.5.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.add_stage` instead. A single request can include multiple instances, so multiple queued requests will also have multiple instances. If you need to process multiple instances through multiple threads or other parallel processing capabilities in `preprocess` or `postprocess`, such as using MindData concurrency ability to process multiple input images in `preprocess`, MindSpore Serving provides 'call_preprocess_pipeline' and 'call_postprocess_pipeline' to register such preprocessing and postprocessing. For more detail, please refer to `Resnet50 model configuration example `_. Args: preprocess_fun (function): Python pipeline function for preprocess. args: Preprocess inputs. The length of 'args' should equal to the input parameters number of implemented python function. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. Examples: >>> from mindspore_serving.server import register >>> import numpy as np >>> def add_trans_datatype(instances): ... for instance in instances: ... x1 = instance[0] ... x2 = instance[1] ... yield x1.astype(np.float32), x2.astype(np.float32) >>> >>> register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) >>> >>> @register.register_method(output_names=["y"]) # register add_cast method in add >>> def add_cast(x1, x2): ... x1, x2 = register.call_preprocess_pipeline(add_trans_datatype, x1, x2) # cast input to float32 ... y = register.call_servable(x1, x2) ... return y """ global method_def_context_ global has_called_preprocess_, has_called_servable_, has_called_postprocess_ if has_called_preprocess_: raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', " f"call_preprocess or call_preprocess_pipeline should not be invoked more than once") if has_called_servable_: raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', " f"call_servable should be invoked after call_preprocess_pipeline") if has_called_postprocess_: raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', call_postprocess " f"or call_postprocess_pipeline should be invoked after call_preprocess_pipeline") has_called_preprocess_ = True outputs_count = _get_stage_outputs_count('call_preprocess_pipeline') return add_stage(preprocess_fun, *args, outputs_count=outputs_count, batch_size=0, tag="Preprocess") @deprecated("1.5.0", "mindspore_serving.server.register.add_stage") def call_postprocess(postprocess_fun, *args): r"""For method registration, define the postprocessing function and its parameters. .. warning:: 'call_postprocess' is deprecated from version 1.5.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.add_stage` instead. Note: The length of 'args' should be equal to the inputs number of postprocess_fun. Args: postprocess_fun (function): Python function for postprocess. args: Postprocess inputs. The length of 'args' should equal to the input parameters number of implemented python function. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. """ global method_def_context_ global has_called_postprocess_ if has_called_postprocess_: raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', " f"call_postprocess or call_postprocess_pipeline should not be invoked more than once") has_called_postprocess_ = True outputs_count = _get_stage_outputs_count('call_postprocess') return add_stage(postprocess_fun, *args, outputs_count=outputs_count, tag="Postprocess") @deprecated("1.5.0", "mindspore_serving.server.register.add_stage") def call_postprocess_pipeline(postprocess_fun, *args): r"""For method registration, define the postprocessing pipeline function and its parameters. .. warning:: 'call_postprocess_pipeline' is deprecated from version 1.5.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.add_stage` instead. A single request can include multiple instances, so multiple queued requests will also have multiple instances. If you need to process multiple instances through multiple threads or other parallel processing capabilities in `preprocess` or `postprocess`, such as using MindData concurrency ability to process multiple input images in `preprocess`, MindSpore Serving provides 'call_preprocess_pipeline' and 'call_postprocess_pipeline' to register such preprocessing and postprocessing. For more detail, please refer to `Resnet50 model configuration example `_. Args: postprocess_fun (function): Python pipeline function for postprocess. args: Postprocess inputs. 
The length of 'args' should equal to the input parameters number of implemented python function. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. """ global method_def_context_ global has_called_postprocess_ if has_called_postprocess_: raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', " f"call_postprocess or call_postprocess_pipeline should not be invoked more than once") has_called_postprocess_ = True outputs_count = _get_stage_outputs_count('call_postprocess_pipeline') return add_stage(postprocess_fun, *args, outputs_count=outputs_count, batch_size=0, tag="Postprocess") @deprecated("1.5.0", "mindspore_serving.server.register.add_stage") def call_servable(*args): r"""For method registration, define the inputs data of model inference. .. warning:: 'call_servable' is deprecated from version 1.5.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.add_stage` instead. Note: The length of 'args' should be equal to the inputs number of model. Args: args: Model's inputs, the length of 'args' should be equal to the inputs number of model. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. Examples: >>> from mindspore_serving.server import register >>> register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) >>> >>> @register.register_method(output_names=["y"]) # register add_common method in add >>> def add_common(x1, x2): ... y = register.call_servable(x1, x2) ... return y """ global method_def_context_ global has_called_servable_, has_called_postprocess_ method_name = method_def_context_.method_name if has_called_servable_: raise RuntimeError(f"Check failed in method '{method_name}', " f"call_servable should not be invoked more than once") if has_called_postprocess_: raise RuntimeError(f"Check failed in method '{method_name}', " f"call_postprocess or call_postprocess_pipeline should be invoked after call_servable") has_called_servable_ = True if not g_declared_models: raise RuntimeError(f"There is no model declared, you can use declare_model to declare models.") outputs_count = _get_stage_outputs_count("call_servable") if len(g_declared_models) == 1: model = g_declared_models[0] else: raise RuntimeError( f"There are more than one servable declared when the interface 'call_servable' is used, use 'add_stage'" f" to replace 'call_servable'") return add_stage(model, *args, outputs_count=outputs_count) def add_stage(stage, *args, outputs_count, batch_size=None, tag=None): r"""In the `servable_config.py` file of one servable, we use `register_method` to wrap a Python function to define a `method` of the servable, and `add_stage` is used to define a stage of this `method`, which can be a Python function or a model. Note: The length of 'args' should be equal to the inputs number of function or model. Args: stage (Union(function, Model)): User-defined python function or `Model` object return by declare_model. outputs_count (int): Outputs count of the user-defined python function or model. batch_size (int, optional): This parameter is valid only when stage is a function and the function can process multi instances at a time. default ``None``. - ``None``, The input of the function will be the inputs of one instance. - ``0``, The input of the function will be tuple object of instances, and the maximum number of the instances is determined by the server based on the batch size of models. 
- int value >= 1, The input of the function will be tuple object of instances, and the maximum number of the instances is the value specified by 'batch_size'. args: Stage inputs placeholders, which come from the inputs of the function wrapped by register_method or the outputs of add_stage. The length of 'args' should equal to the input number of the function or model. tag (str, optional): Customized flag of the stage, such as ``"Preprocess"``, default ``None``. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. Examples: >>> import numpy as np >>> from mindspore_serving.server import register >>> add_model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") >>> >>> def preprocess(x1, x2): ... return x1.astype(np.float32), x2.astype(np.float32) >>> >>> @register.register_method(output_names=["y"]) # register add_common method in add >>> def add_common(x1, x2): ... x1, x2 = register.add_stage(preprocess, x1, x2, outputs_count=2) # call preprocess in stage 1 ... y = register.add_stage(add_model, x1, x2, outputs_count=1) # call add model in stage 2 ... return y """ global method_def_context_ global cur_stage_index_ method_name = method_def_context_.method_name if tag is not None: check_type.check_str("tag", tag) else: tag = "" for item in args: if not isinstance(item, _TensorDef): raise RuntimeError(f"Each value of parameter *args is a placeholder for data and must come from the method" f" inputs or the outputs of add_stage") func_inputs = [item.as_pair() for item in args] inputs_count = len(args) if isinstance(stage, Model): if stage not in g_declared_models: raise RuntimeError( f"Check failed in method '{method_name}', the parameter 'stage' of add_stage must be function " f"or Model returned by declare_model, and ensure that interface 'declare_model' can take effect " f"when importing servable_config.py by the serving server") model = stage model_key = model.model_key ServableRegister_.register_model_input_output_info(model_key, inputs_count, outputs_count, 0) method_def_context_.add_stage_model(model_key, func_inputs, 0, tag) elif inspect.isfunction(stage): if batch_size is None: register_stage_function(method_name, _wrap_fun_to_batch(stage, inputs_count), inputs_count=inputs_count, outputs_count=outputs_count, use_with_size=False) batch_size = 0 else: check_type.check_int("batch_size", batch_size, 0) register_stage_function(method_name, stage, inputs_count=inputs_count, outputs_count=outputs_count, use_with_size=True) func_name = get_servable_dir() + "." 
+ get_func_name(stage) method_def_context_.add_stage_function(func_name, func_inputs, batch_size, tag) else: if not isinstance(stage, str): raise RuntimeError( f"Check failed in method '{method_name}', the parameter 'stage' of add_stage must be function " f"or Model returned by declare_model, now is {type(stage)}") func_name = stage check_stage_function(method_name, func_name, inputs_count=inputs_count, outputs_count=outputs_count) method_def_context_.add_stage_function(func_name, func_inputs, 0, tag) cur_stage_index_ += 1 # call_xxx stage index start begin 1 return _create_tensor_def_outputs(cur_stage_index_, outputs_count) _call_servable_name = call_servable.__name__ _call_stage_names = [call_preprocess.__name__, call_postprocess.__name__] _call_stage_batch_names = [call_preprocess_pipeline.__name__, call_postprocess_pipeline.__name__] def _ast_node_info(method_def_func, ast_node): """Ast node code info""" func_name = method_def_func.__name__ func_codes = inspect.getsource(method_def_func) func_codes_lines = func_codes.split("\n") _, start_lineno = inspect.findsource(method_def_func) codes = "" if hasattr(ast_node, "end_lineno"): end_lineno = ast_node.end_lineno else: end_lineno = ast_node.lineno for line in range(ast_node.lineno, end_lineno + 1): codes += func_codes_lines[line - 1] + "\n" lineno = ast_node.lineno + start_lineno end_lineno = end_lineno + start_lineno if lineno != end_lineno: line_info = f"{lineno}~{end_lineno}" else: line_info = f"{lineno}" return f"line {line_info} in {func_name}, code: \n" + codes def _get_method_def_stage_meta(method_def_func): """Parse register_method func, and get the input and output count of preprocess, servable and postprocess""" source = inspect.getsource(method_def_func) method_name = method_def_func.__name__ call_list = ast.parse(source).body[0].body func_meta = {} code_infos = [] code_other = None def update_other_code(code): nonlocal code_other if not code_other: code_other = code for call_item in call_list: if isinstance(call_item, ast.Return): continue if isinstance(call_item, ast.Expr): continue if not isinstance(call_item, ast.Assign): update_other_code(call_item) continue target = call_item.targets[0] if isinstance(target, ast.Name): outputs_count = 1 elif isinstance(target, ast.Tuple): outputs_count = len(target.elts) else: continue call = call_item.value if not isinstance(call, ast.Call): continue func = call.func if isinstance(func, ast.Attribute): func_name = func.attr elif isinstance(func, ast.Name): func_name = func.id else: update_other_code(call_item) continue inputs_count = len(call.args) if func_name in _call_stage_names or func_name in _call_stage_batch_names: inputs_count -= 1 elif func_name == _call_servable_name: pass else: update_other_code(call_item) continue if inputs_count <= 0: raise RuntimeError(f"Invalid '{func_name}' invoke args") logger.info(f"stage {len(func_meta) + 1} call type '{func_name}', inputs count {inputs_count}, " f"outputs count {outputs_count}") func_meta[func_name] = [inputs_count, outputs_count] code_infos.append([call_item, func_name]) if code_infos and code_other: call_names = [item[1] for item in code_infos] call_names = ";".join(call_names) raise RuntimeError( f"Failed to parse method '{method_name}', complex statements such as conditions and loops are not supported" f" in register_method when the interface '{call_names}' is used, use 'add_stage' to replace '{call_names}'," f" code {type(code_other)}: {_ast_node_info(method_def_func, code_other)}") if code_infos and _call_servable_name not in 
func_meta: raise RuntimeError(f"Cannot find the invocation of '{_call_servable_name}'") return func_meta def register_method(output_names): """Define a method of the servable when importing servable_config.py of one servable. One servable can include one or more methods, and each method provides different services based on models. A client needs to specify the servable name and method name when accessing one service. MindSpore Serving supports a service consisting of multiple python functions and multiple models. Note: This interface should take effect when importing servable_config.py by the serving server. Therefore, it's recommended that this interface be used globally in servable_config.py. This interface will define the signatures and pipeline of the method. The signatures include the method name, input and output names of the method. When accessing a service, the client needs to specify the servable name, the method name, and provide one or more inference instances. Each instance specifies the input data by the input names and obtains the output data by the output names. The pipeline consists of one or more stages, each stage can be a python function or a model. That is, a pipeline can include one or more python functions and one or more models. In addition, the interface also defines the data flow of these stages. Args: output_names (Union[str, tuple[str], list[str]]): The output names of method. The input names are the argument names of the registered function. Raises: RuntimeError: The type or value of the parameters are invalid, or other error happened. Examples: >>> from mindspore_serving.server import register >>> add_model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") >>> sub_model = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR") >>> >>> @register.register_method(output_names=["y"]) # register predict method in servable >>> def predict(x1, x2, x3): # x1+x2-x3 ... y = register.add_stage(add_model, x1, x2, outputs_count=1) ... y = register.add_stage(sub_model, y, x3, outputs_count=1) ... 
return y """ output_names = check_type.check_and_as_str_tuple_list('output_names', output_names) def register(func): name = get_func_name(func) sig = inspect.signature(func) input_names = [] for k, v in sig.parameters.items(): if v.kind == inspect.Parameter.VAR_POSITIONAL: raise RuntimeError(f"'{name}' input {k} cannot be VAR_POSITIONAL !") if v.kind == inspect.Parameter.VAR_KEYWORD: raise RuntimeError(f"'{name}' input {k} cannot be VAR_KEYWORD !") if v.kind == inspect.Parameter.KEYWORD_ONLY: raise RuntimeError(f"'{name}' input {k} cannot be KEYWORD_ONLY !") input_names.append(k) input_tensors = [] for i in range(len(input_names)): input_tensors.append(_TensorDef(0, i)) servable_name = get_servable_dir() global method_def_context_ method_def_context_ = MethodSignature_() method_def_context_.servable_name = servable_name method_def_context_.method_name = name method_def_context_.inputs = input_names method_def_context_.outputs = output_names global method_def_ast_meta_ method_def_ast_meta_ = _get_method_def_stage_meta(func) global cur_stage_index_ cur_stage_index_ = 0 global has_called_preprocess_, has_called_servable_, has_called_postprocess_ has_called_preprocess_ = False has_called_servable_ = False has_called_postprocess_ = False output_tensors = func(*tuple(input_tensors)) if method_def_ast_meta_ and cur_stage_index_ != len(method_def_ast_meta_): raise RuntimeError(f"Failed to parse method {name}, the number of stages obtained through the AST " f"{len(method_def_ast_meta_)} is inconsistent with the running result {cur_stage_index_}" f". Condition and loop statements are not supported in methods currently.") if isinstance(output_tensors, _TensorDef): output_tensors = (output_tensors,) for item in output_tensors: if not isinstance(item, _TensorDef): raise RuntimeError(f"Each value returned is a placeholder for data and must come from the method" f" inputs or the outputs of add_stage") if len(output_tensors) != len(output_names): raise RuntimeError( f"Method return output size {len(output_tensors)} not match registered {len(output_names)}") return_inputs = [item.as_pair() for item in output_tensors] method_def_context_.set_return(return_inputs) logger.info(f"Register method: method_name {method_def_context_.method_name}, " f"inputs: {input_names}, outputs: {output_names}") ServableRegister_.register_method(method_def_context_) return func return register ================================================ FILE: mindspore_serving/server/register/model.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ """Servable declaration interface""" from mindspore_serving._mindspore_serving import ModelMeta_, ServableRegister_, ModelContext_ from mindspore_serving import log as logger from mindspore_serving.server.common import check_type, deprecated from .utils import get_servable_dir g_declared_models = [] @deprecated("1.5.0", "mindspore_serving.server.register.declare_model") def declare_servable(servable_file, model_format, with_batch_dim=True, options=None, without_batch_dim_inputs=None): r""" Declare one model. .. warning:: 'register.declare_servable' is deprecated from version 1.5.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.declare_model` instead. Args: servable_file (Union[str, list[str]]): Model files name. model_format (str): Model format, ``"OM"`` or ``"MindIR"``, case ignored. with_batch_dim (bool, optional): Whether the first shape dim of the inputs and outputs of model is batch dim. Default: ``True``. options (Union[AclOptions, GpuOptions], optional): Options of model, supports AclOptions or GpuOptions. Default: ``None``. without_batch_dim_inputs (Union[int, tuple[int], list[int]], optional): Index of inputs without batch dim when `with_batch_dim` is ``True``. Default: ``None``. Raises: RuntimeError: The type or value of the parameters are invalid. Return: Model, identification of this model, used as input of add_stage. """ return declare_model(servable_file, model_format, with_batch_dim, options, without_batch_dim_inputs) class Model: """Indicate a model. Users should not construct a Model object directly; it needs to be returned from `declare_model` or `declare_servable`. Args: model_key (str): Model key that identifies the model. """ def __init__(self, model_key): self.model_key = model_key def call(self, *args, subgraph=0): r"""Invoke the model inference interface based on instances. Args: args: tuple/list of instances, or inputs of one instance. subgraph (int, optional): Subgraph index, used when there are multiple sub-graphs in one model. Default: ``0``. Return: Tuple of instances when input parameter 'args' is tuple/list, or outputs of one instance. Raises: RuntimeError: Inputs are invalid. Examples: >>> import numpy as np >>> from mindspore_serving.server import register >>> import mindspore.dataset.vision.c_transforms as VC >>> model = register.declare_model(model_file="resnet_bs32.mindir", model_format="MindIR") # batch_size=32 >>> >>> def preprocess(image): ... decode = VC.Decode() ... resize = VC.Resize([224, 224]) ... normalize = VC.Normalize(mean=[125.307, 122.961, 113.8575], std=[51.5865, 50.847, 51.255]) ... hwc2chw = VC.HWC2CHW() ... image = decode(image) ... image = resize(image) # [3,224,224] ... image = normalize(image) # [3,224,224] ... image = hwc2chw(image) # [3,224,224] ... return image >>> >>> def postprocess(score): ... return np.argmax(score) >>> >>> def call_resnet_model(image): ... image = preprocess(image) ... score = model.call(image) # for only one instance ... return postprocess(score) >>> >>> def call_resnet_model_batch(instances): ... input_instances = [] ... for instance in instances: ... image = instance[0] # only one input ... image = preprocess(image) # [3,224,224] ... input_instances.append([image]) ... output_instances = model.call(input_instances) # for multiple instances ... for instance in output_instances: ... score = instance[0] # only one output for each instance ... index = postprocess(score) ...
yield index >>> >>> @register.register_method(output_names=["index"]) >>> def predict_v1(image): # without pipeline, call model with only one instance at a time ... index = register.add_stage(call_resnet_model, image, outputs_count=1) ... return index >>> >>> @register.register_method(output_names=["index"]) >>> def predict_v2(image): # without pipeline, call model with maximum 32 instances at a time ... index = register.add_stage(call_resnet_model_batch, image, outputs_count=1, batch_size=32) ... return index >>> >>> @register.register_method(output_names=["index"]) >>> def predict_v3(image): # pipeline ... image = register.add_stage(preprocess, image, outputs_count=1) ... score = register.add_stage(model, image, outputs_count=1) ... index = register.add_stage(postprocess, score, outputs_count=1) ... return index """ check_type.check_int("subgraph", subgraph, 0) subgraph_str = "" if subgraph != 0: subgraph_str = ", subgraph=" + str(subgraph) if not args: raise RuntimeError(f"Model({self.model_key}{subgraph_str}).call() failed: no inputs provided, the inputs " f"can be call(x1, x2) for a single instance or call([[x1, x2], [x1, x2]]) for multiple " f"instances.") instances = [] instance_format = False if len(args) == 1 and isinstance(args[0], (tuple, list)): instance_format = True inputs = args[0] for instance in inputs: if not isinstance(instance, (tuple, list)): raise RuntimeError(f"Model({self.model_key}{subgraph_str}).call() failed: inputs format invalid, " f"the inputs can be call(x1, x2) for a single instance or " f"call([[x1, x2], [x1, x2]]) for multiple instances.") instances.append(tuple(instance)) else: instances.append(tuple(args)) output = ServableRegister_.run(self.model_key, tuple(instances), subgraph) if not instance_format: output = output[0] if len(output) == 1: return output[0] return output return output def append_declared_model(model_key): global g_declared_models model = Model(model_key) g_declared_models.append(model) return model def declare_model(model_file, model_format, with_batch_dim=True, options=None, without_batch_dim_inputs=None, context=None, config_file=None): r""" Declare one model when importing servable_config.py of one servable. Note: This interface takes effect when servable_config.py is imported by the serving server. Therefore, it's recommended that this interface be used globally in servable_config.py. .. warning:: The parameter 'options' is deprecated from version 1.6.0 and will be removed in a future version, use parameter 'context' instead. Args: model_file (Union[str, list[str]]): Model files name. model_format (str): Model format, ``"MindIR"`` or ``"MindIR_Lite"``, case ignored. with_batch_dim (bool, optional): Whether the first shape dim of the inputs and outputs of model is batch dim. Default: ``True``. options (Union[AclOptions, GpuOptions], optional): Options of model, supports AclOptions or GpuOptions. Default: ``None``. context (Context, optional): Context is used to store environment variables during execution. If the value is ``None``, Serving uses the default device context based on the deployed device. Default: ``None``. without_batch_dim_inputs (Union[int, tuple[int], list[int]], optional): Index of inputs without batch dim when `with_batch_dim` is ``True``. For example, if the shape of input 0 does not include the batch dimension, `without_batch_dim_inputs` can be set to `(0,)`. Default: ``None``. config_file (str, optional): Config file for the model to set mixed precision inference.
The file path can be an absolute path or a relative path to the directory in which servable_config.py resides. Default: ``None``. Return: Model, identification of this model, can be used for `Model.call` or as the inputs of `add_stage`. Raises: RuntimeError: The type or value of the parameters are invalid. """ check_type.check_bool('with_batch_dim', with_batch_dim) meta = ModelMeta_() model_file = check_type.check_and_as_str_tuple_list('model_file', model_file) meta.common_meta.servable_name = get_servable_dir() meta.common_meta.model_key = ";".join(model_file) meta.common_meta.with_batch_dim = with_batch_dim if without_batch_dim_inputs: without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', without_batch_dim_inputs, 0) meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs # init local servable meta info check_type.check_str('model_format', model_format) model_format = model_format.lower() if model_format not in ("om", "mindir", "mindir_opt", "mindir_lite"): raise RuntimeError("model format can only be OM, MindIR, MindIR_Opt or MindIR_Lite, case ignored") meta.local_meta.model_file = model_file meta.local_meta.set_model_format(model_format) if context is not None: if not isinstance(context, Context): raise RuntimeError(f"Parameter 'context' should be Context, but got {type(context)}") meta.local_meta.model_context = context.model_context elif isinstance(options, (GpuOptions, AclOptions)): logger.warning( "'options' will be deprecated in the future, we recommend using 'context'; if these two parameters " "are both set, 'options' will be ignored") meta.local_meta.model_context = options.context.model_context elif options is not None: raise RuntimeError(f"Parameter 'options' should be None, GpuOptions or AclOptions, but " f"got {type(options)}") if config_file is not None: check_type.check_str("config_file", config_file) meta.local_meta.config_file = config_file ServableRegister_.declare_model(meta) logger.info(f"Declare model, model_file: {model_file}, model_format: {model_format}, with_batch_dim: " f"{with_batch_dim}, options: {options}, without_batch_dim_inputs: {without_batch_dim_inputs}" f", context: {context}, config file: {config_file}") return append_declared_model(meta.common_meta.model_key) class Context: """ Context is used to customize device configurations. If Context is not specified, MindSpore Serving uses the default device configurations. When the inference backend is MindSpore Lite and the device type is Ascend or GPU, an extra `CPUDeviceInfo` will be used. Args: thread_num (int, optional): Set the number of threads at runtime. Only valid when the inference backend is MindSpore Lite. thread_affinity_core_list (tuple[int], list[int], optional): Set the list of CPU cores to which threads are bound. Only valid when the inference backend is MindSpore Lite. enable_parallel (bool, optional): Set whether to perform model inference or training in parallel. Only valid when the inference backend is MindSpore Lite. Raises: RuntimeError: type or value of input parameters are invalid.
Examples: >>> from mindspore_serving.server import register >>> import numpy as np >>> context = register.Context(thread_num=1, thread_affinity_core_list=[1,2], enable_parallel=True) >>> context.append_device_info(register.GPUDeviceInfo(precision_mode="fp16")) >>> model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", context=context) """ def __init__(self, **kwargs): self.model_context = ModelContext_() val_set_fun = { "thread_num": self._set_thread_num, "thread_affinity_core_list": self._set_thread_affinity_core_list, "enable_parallel": self._set_enable_parallel } for k, v in kwargs.items(): if k not in val_set_fun: raise RuntimeError("Set context failed, unsupported option " + k) val_set_fun[k](v) self.device_types = [] def append_device_info(self, device_info): """Append one user-defined device info to the context. Args: device_info (Union[CPUDeviceInfo, GPUDeviceInfo, AscendDeviceInfo]): User-defined device info for one device; otherwise default values are used. You can customize device info for each device, and the system selects the required device info based on the actual backend device and MindSpore inference package. Raises: RuntimeError: type or value of input parameters are invalid. """ if not isinstance(device_info, DeviceInfoContext): raise RuntimeError(f"Parameter 'device_info' should be an instance of CPUDeviceInfo, GPUDeviceInfo, or " f"AscendDeviceInfo, but actually {type(device_info)}") # pylint: disable=protected-access info_map = device_info._as_context_map() if not info_map["device_type"]: raise RuntimeError("Invalid DeviceInfoContext, device_type cannot be empty") device_type = info_map["device_type"] if device_type in self.device_types: raise RuntimeError(f"Device info of type {device_type} has already been appended") self.device_types.append(device_type) self.model_context.append_device_info(info_map) def _set_thread_num(self, val): check_type.check_int("thread_num", val, 1) self.model_context.thread_num = val def _set_thread_affinity_core_list(self, val): check_type.check_int_tuple_list("thread_affinity_core_list", val, 0) self.model_context.thread_affinity_core_list = val def _set_enable_parallel(self, val): check_type.check_bool("enable_parallel", val) if val: self.model_context.enable_parallel = 1 else: self.model_context.enable_parallel = 0 def __str__(self): res = f"thread_num: {self.model_context.thread_num}, thread_affinity_core_list: " \ f"{self.model_context.thread_affinity_core_list}, enable_parallel: " \ f"{self.model_context.enable_parallel}, device_list: {self.model_context.device_list}" return res class DeviceInfoContext: def __init__(self): """ Initialize context""" def _as_context_map(self): """Transfer device info to dict of str,str""" raise NotImplementedError class CPUDeviceInfo(DeviceInfoContext): """ Helper class to set cpu device info. Args: precision_mode(str, optional): Option of model precision, and the value can be ``"origin"``, ``"fp16"``. ``"origin"`` indicates that inference is performed with the precision defined in the model, and ``"fp16"`` indicates that inference is performed based on FP16 precision. Default: ``"origin"``. Raises: RuntimeError: Cpu option is invalid, or value is not str.
Examples: >>> from mindspore_serving.server import register >>> context = register.Context() >>> context.append_device_info(register.CPUDeviceInfo(precision_mode="fp16")) >>> model = register.declare_model(model_file="deeptext.ms", model_format="MindIR_Lite", context=context) """ def __init__(self, **kwargs): super(CPUDeviceInfo, self).__init__() self.precision_mode = "" val_set_fun = {"precision_mode": self._set_precision_mode} for k, w in kwargs.items(): if k not in val_set_fun: raise RuntimeError("Set cpu device info failed, unsupported option " + k) val_set_fun[k](w) self.context_map = self._as_context_map() def _set_precision_mode(self, val): check_type.check_str("precision_mode", val) if val not in ("origin", "fp16"): raise RuntimeError(f"Cpu device info 'precision_mode' can only be 'origin' or 'fp16', but given '{val}'") self.precision_mode = val def _as_context_map(self): """Transfer cpu device info to dict of str,str""" context_map = {} if self.precision_mode: context_map["precision_mode"] = self.precision_mode context_map["device_type"] = "cpu" return context_map class GPUDeviceInfo(DeviceInfoContext): """ Helper class to set gpu device info. Args: precision_mode(str, optional): Option of model precision, and the value can be ``"origin"``, ``"fp16"``. ``"origin"`` indicates that inference is performed with the precision defined in the model, and ``"fp16"`` indicates that inference is performed based on FP16 precision. Default: ``"origin"``. Raises: RuntimeError: Gpu option is invalid, or value is not str. Examples: >>> from mindspore_serving.server import register >>> context = register.Context() >>> context.append_device_info(register.GPUDeviceInfo(precision_mode="fp16")) >>> model = register.declare_model(model_file="deeptext.mindir", model_format="MindIR", context=context) """ def __init__(self, **kwargs): super(GPUDeviceInfo, self).__init__() self.precision_mode = "" val_set_fun = {"precision_mode": self._set_precision_mode} for k, w in kwargs.items(): if k not in val_set_fun: raise RuntimeError("Set gpu device info failed, unsupported option " + k) val_set_fun[k](w) self.context_map = self._as_context_map() def _set_precision_mode(self, val): """Set option 'precision_mode', which means inference operator selection, and the value can be "origin" or "fp16", default "origin". Args: val (str): Value of option 'precision_mode'. "origin" inference with model definition. "fp16" enable FP16 operator selection, with FP32 fallback. Default: "origin". Raises: RuntimeError: The type of value is not str, or the value is invalid. """ check_type.check_str('precision_mode', val) if val not in ("origin", "fp16"): raise RuntimeError(f"Gpu device info 'precision_mode' can only be 'origin' or 'fp16', but given '{val}'") self.precision_mode = val def _as_context_map(self): """Transfer gpu device info to dict of str,str""" context_map = {} if self.precision_mode: context_map["precision_mode"] = self.precision_mode context_map["device_type"] = "gpu" return context_map class AscendDeviceInfo(DeviceInfoContext): """ Helper class to set Ascend device info. Args: insert_op_cfg_path (str, optional): Path of aipp config file. input_format (str, optional): Manually specify the model input format, the value can be ``"ND"``, ``"NCHW"``, ``"NHWC"``, ``"CHWN"``, ``"NC1HWC0"``, or ``"NHWC1C0"``. input_shape (str, optional): Manually specify the model input shape, such as ``"input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"``.
output_type (str, optional): Manually specify the model output type, the value can be ``"FP16"``, ``"UINT8"`` or ``"FP32"``. Default: ``"FP32"``. precision_mode (str, optional): Model precision mode, the value can be ``"force_fp16"``, ``"allow_fp32_to_fp16"``, ``"must_keep_origin_dtype"`` or ``"allow_mix_precision"``. Default: ``"force_fp16"``. op_select_impl_mode (str, optional): The operator selection mode, the value can be ``"high_performance"`` or ``"high_precision"``. Default: ``"high_performance"``. fusion_switch_config_path (str, optional): Configuration file path of the fusion rules, including graph fusion and UB fusion. The system has built-in graph fusion and UB fusion rules, which are enabled by default. You can disable the rules specified in the file by setting this parameter. buffer_optimize_mode (str, optional): The value can be ``"l1_optimize"``, ``"l2_optimize"``, ``"off_optimize"`` or ``"l1_and_l2_optimize"``. Default: ``"l2_optimize"``. Raises: RuntimeError: Ascend device info is invalid. Examples: >>> from mindspore_serving.server import register >>> context = register.Context() >>> context.append_device_info(register.AscendDeviceInfo(input_format="NCHW")) >>> model = register.declare_model(model_file="deeptext.ms", model_format="MindIR_Lite", context=context) """ def __init__(self, **kwargs): super(AscendDeviceInfo, self).__init__() self.insert_op_cfg_path = "" self.input_format = "" self.input_shape = "" self.output_type = "" self.precision_mode = "" self.op_select_impl_mode = "" self.fusion_switch_config_path = "" self.buffer_optimize_mode = "" val_set_fun = {"insert_op_cfg_path": self._set_insert_op_cfg_path, "input_format": self._set_input_format, "input_shape": self._set_input_shape, "output_type": self._set_output_type, "precision_mode": self._set_precision_mode, "op_select_impl_mode": self._set_op_select_impl_mode, "fusion_switch_config_path": self._set_fusion_switch_config_path, "buffer_optimize_mode": self._set_buffer_optimize_mode} for k, w in kwargs.items(): if k not in val_set_fun: raise RuntimeError("Set ascend device info failed, unsupported parameter " + k) val_set_fun[k](w) self.context_map = self._as_context_map() def _set_insert_op_cfg_path(self, val): """Set option 'insert_op_cfg_path' Args: val (str): Value of option 'insert_op_cfg_path'. Raises: RuntimeError: The type of value is not str. """ check_type.check_str('insert_op_cfg_path', val) self.insert_op_cfg_path = val def _set_input_format(self, val): """Set option 'input_format', manually specify the model input format, and the value can be "ND", "NCHW", "NHWC", "CHWN", "NC1HWC0", or "NHWC1C0". Args: val (str): Value of option 'input_format', and the value can be "ND", "NCHW", "NHWC", "CHWN", "NC1HWC0", or "NHWC1C0". Raises: RuntimeError: The type of value is not str, or the value is invalid. """ check_type.check_str('input_format', val) if val not in ("ND", "NCHW", "NHWC", "CHWN", "NC1HWC0", "NHWC1C0"): raise RuntimeError(f"Ascend device info 'input_format' can only be 'ND', 'NCHW', 'NHWC', 'CHWN', 'NC1HWC0'" f", or 'NHWC1C0', actually given '{val}'") self.input_format = val def _set_input_shape(self, val): """Set option 'input_shape', manually specify the model input shape, such as "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1". Args: val (str): Value of option 'input_shape'. Raises: RuntimeError: The type of value is not str, or the value is invalid.
""" check_type.check_str('input_shape', val) self.input_shape = val def _set_output_type(self, val): """Set option 'output_type', manually specify the model output type, and the value can be "FP16", "UINT8", or "FP32", default "FP32". Args: val (str): Value of option 'output_type', and the value can be "FP16", "UINT8", or "FP32", default "FP32". Raises: RuntimeError: The type of value is not str, or the value is invalid. """ check_type.check_str('output_type', val) if val not in ("FP32", "FP16", "UINT8"): raise RuntimeError(f"Ascend device info 'op_select_impl_mode' can only be 'FP32'(default), 'FP16' or " f"'UINT8', actually given '{val}'") self.output_type = val def _set_precision_mode(self, val): """Set option 'precision_mode', which means operator selection mode, and the value can be "force_fp16", "force_fp16", "must_keep_origin_dtype", or "allow_mix_precision", default "force_fp16". Args: val (str): Value of option 'precision_mode', and the value can be "force_fp16", "force_fp16", "must_keep_origin_dtype", or "allow_mix_precision", default "force_fp16". Raises: RuntimeError: The type of value is not str, or the value is invalid. """ check_type.check_str('precision_mode', val) if val not in ("force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", "allow_mix_precision"): raise RuntimeError(f"Ascend device info 'precision_mode' can only be 'force_fp16'(default), " f"'allow_fp32_to_fp16' 'must_keep_origin_dtype' or 'allow_mix_precision', " f"actually given '{val}'") self.precision_mode = val def _set_op_select_impl_mode(self, val): """Set option 'op_select_impl_mode', which means model precision mode, and the value can be "high_performance" or "high_precision", default "high_performance". Args: val (str): Value of option 'op_select_impl_mode',which can be "high_performance" or "high_precision", default "high_performance". Raises: RuntimeError: The type of value is not str, or the value is invalid. 
""" check_type.check_str('op_select_impl_mode', val) if val not in ("high_performance", "high_precision"): raise RuntimeError(f"Ascend device info 'op_select_impl_mode' can only be 'high_performance'(default) or " f"'high_precision', actually given '{val}'") self.op_select_impl_mode = val def _set_fusion_switch_config_path(self, val): check_type.check_str('fusion_switch_config_path', val) self.fusion_switch_config_path = val def _set_buffer_optimize_mode(self, val): check_type.check_str('buffer_optimize_mode', val) if val not in ("l1_optimize", "l2_optimize", "off_optimize", "l1_and_l2_optimize"): raise RuntimeError(f"Ascend device info 'buffer_optimize_mode' can only be 'off_optimize'(default), " f"'l1_optimize', 'l2_optimize' or 'l1_and_l2_optimize', actually given '{val}'") self.buffer_optimize_mode = val def _as_context_map(self): """Transfer acl device info to dict of str,str""" context_map = {} if self.insert_op_cfg_path: context_map["insert_op_cfg_path"] = self.insert_op_cfg_path if self.input_format: context_map["input_format"] = self.input_format if self.input_shape: context_map["input_shape"] = self.input_shape if self.output_type: context_map["output_type"] = self.output_type if self.precision_mode: context_map["precision_mode"] = self.precision_mode if self.op_select_impl_mode: context_map["op_select_impl_mode"] = self.op_select_impl_mode if self.buffer_optimize_mode: context_map["buffer_optimize_mode"] = self.buffer_optimize_mode if self.fusion_switch_config_path: context_map["fusion_switch_config_path"] = self.fusion_switch_config_path context_map["device_type"] = "ascend" return context_map class AclOptions: """ Helper class to set Ascend device infos. .. warning:: 'AclOptions' is deprecated from version 1.6.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.AscendDeviceInfo` instead. Args: insert_op_cfg_path (str, optional): Path of aipp config file. input_format (str, optional): Manually specify the model input format, the value can be ``"ND"``, ``"NCHW"``, ``"NHWC"``, ``"CHWN"``, ``"NC1HWC0"``, or ``"NHWC1C0"``. input_shape (str, optional): Manually specify the model input shape, such as ``"input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"``. output_type (str, optional): Manually specify the model output type, the value can be ``"FP16"``, ``"UINT8"`` or ``"FP32"``. Default: ``"FP32"``. precision_mode (str, optional): Model precision mode, the value can be ``"force_fp16"``, ``"allow_fp32_to_fp16"``, ``"must_keep_origin_dtype"`` or ``"allow_mix_precision"``. Default: ``"force_fp16"``. op_select_impl_mode (str, optional): The operator selection mode, the value can be ``"high_performance"`` or ``"high_precision"``. Default: ``"high_performance"``. Raises: RuntimeError: Acl option is invalid, or value is not str. Examples: >>> from mindspore_serving.server import register >>> options = register.AclOptions(op_select_impl_mode="high_precision", precision_mode="allow_fp32_to_fp16") >>> register.declare_servable(servable_file="deeptext.mindir", model_format="MindIR", options=options) """ def __init__(self, **kwargs): super(AclOptions, self).__init__() logger.warning("'AclOptions' is deprecated from version 1.6.0 and will be removed in a future version, " "use 'mindspore_serving.server.register.AscendDeviceInfo' instead.") device_info = AscendDeviceInfo(**kwargs) self.context = Context() self.context.append_device_info(device_info) class GpuOptions: """ Helper class to set gpu options. .. 
warning:: 'GpuOptions' is deprecated from version 1.6.0 and will be removed in a future version, use :class:`mindspore_serving.server.register.GPUDeviceInfo` instead. Args: precision_mode(str, optional): inference operator selection, and the value can be ``"origin"``, ``"fp16"``. Default: ``"origin"``. Raises: RuntimeError: Gpu option is invalid, or value is not str. Examples: >>> from mindspore_serving.server import register >>> options = register.GpuOptions(precision_mode="origin") >>> register.declare_servable(servable_file="deeptext.mindir", model_format="MindIR", options=options) """ def __init__(self, **kwargs): super(GpuOptions, self).__init__() logger.warning("'GpuOptions' is deprecated from version 1.6.0 and will be removed in a future version, " "use 'mindspore_serving.server.register.GPUDeviceInfo' instead.") device_info = GPUDeviceInfo(**kwargs) self.context = Context() self.context.append_device_info(device_info) ================================================ FILE: mindspore_serving/server/register/stage_function.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Stage function registration interface""" from mindspore_serving._mindspore_serving import StageFunctionStorage_ from mindspore_serving import log as logger from .utils import get_servable_dir, get_func_name def check_stage_function(method_name, function_name, inputs_count, outputs_count): """Check whether the inputs and outputs count equals the last registered""" func_info = get_stage_info(function_name) if not func_info: return last_inputs_count, last_output_count = func_info if last_inputs_count != inputs_count: raise RuntimeError(f"Stage function '{function_name}' inputs count {inputs_count} does not match " f"the last registered count {last_inputs_count}, method name '{method_name}'") if last_output_count != outputs_count: raise RuntimeError(f"Stage function '{function_name}' outputs count {outputs_count} does not match " f"the last registered count {last_output_count}, method name '{method_name}'") def get_stage_info(function_name): """Get cpp and python function inputs and outputs count""" func_info = StageFunctionStorage_.get_instance().get_pycpp_function_info(function_name) if not func_info: return None return func_info class StageFunctionStorage: """Register and get stage function info: func, name, input and output count""" def __init__(self): self.function = {} self.storage = StageFunctionStorage_.get_instance() def register(self, method_name, fun, function_name, inputs_count, outputs_count, use_with_size): check_stage_function(method_name, function_name, inputs_count, outputs_count) if function_name in self.function: if self.function[function_name]["use_with_size"] != use_with_size: raise RuntimeError(f"Failed to add stage function {function_name}: parameter 'batch_size' in " f"multiple 'add_stage' should be enabled or disabled consistently")
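# The same python stage function may be referenced by several add_stage calls,
# even across methods; the checks above require that its inputs/outputs counts
# and its 'batch_size' usage stay consistent with the first registration, and
# the entry recorded below mirrors the counts into the C++ side storage.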
self.function[function_name] = {"fun": fun, "inputs_count": inputs_count, "outputs_count": outputs_count, "use_with_size": use_with_size} self.storage.register(function_name, inputs_count, outputs_count) def get(self, function_name): func = self.function.get(function_name, None) if func is None: raise RuntimeError(f"Stage function '{function_name}' not found") return func stage_function_storage = StageFunctionStorage() def register_stage_function(method_name, func, inputs_count, outputs_count, use_with_size): """register stage function""" servable_name = get_servable_dir() func_name = get_func_name(func) name = servable_name + "." + func_name logger.info(f"Register stage function {name} {inputs_count} {outputs_count}, use batch size: {use_with_size}") stage_function_storage.register(method_name, func, name, inputs_count, outputs_count, use_with_size) ================================================ FILE: mindspore_serving/server/register/utils.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Common implement for worker""" import inspect import os def get_servable_dir(): """Get the directory where servable is located. The name of the directory is the name of servable""" stack = inspect.stack() for item in stack: if item.filename.endswith("servable_config.py"): abs_path = os.path.realpath(item.filename) last_dir = os.path.split(abs_path)[0] last_dir = os.path.split(last_dir)[1] if not last_dir: continue return last_dir raise RuntimeError("Failed to obtain the directory of servable_config.py") def get_func_name(func): """Get function name for preprocess and postprocess, as the identification name""" return func.__name__ ================================================ FILE: mindspore_serving/server/start_extra_worker.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ """Start worker process with single core servable""" import os import signal import argparse from mindspore_serving.server import worker from mindspore_serving.server.common import check_type from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Worker_ def start_extra_worker(servable_directory, servable_name, version_number, device_type, device_ids_empty, index, master_address, dec_key, dec_mode, listening_master, enable_lite): """Start worker process with single core servable""" signal.signal(signal.SIGCHLD, signal.SIG_DFL) # for ccec compiler check_type.check_str('servable_directory', servable_directory) check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 0) check_type.check_str('device_type', device_type) check_type.check_bool('device_ids_empty', device_ids_empty) check_type.check_int('index', index, 0) check_type.check_str('master_address', master_address) check_type.check_bool('listening_master', listening_master) check_type.check_bool('enable_lite', enable_lite) ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message worker_pid = os.getpid() unix_socket_dir = "unix_socket_files" try: os.mkdir(unix_socket_dir) except FileExistsError: pass worker_address = f"unix:{unix_socket_dir}/serving_worker_{servable_name}_version{version_number}_extra{index}" \ f"_{worker_pid}" if len(worker_address) > 90: # limit maximum unix domain socket address length worker_address = worker_address[:40] + "___" + worker_address[-40:] try: worker.start_extra_servable(servable_directory=servable_directory, servable_name=servable_name, version_number=version_number, device_type=device_type, device_ids_empty=device_ids_empty, dec_key=dec_key, dec_mode=dec_mode, master_address=master_address, worker_address=worker_address, enable_lite=enable_lite) except Exception as ex: Worker_.notify_failed(master_address, f"{{servable:{servable_name}, version:{version_number}, extra:{index}, <{ex}>}}") raise def parse_args_and_start(): """Parse args and start extra worker""" parser = argparse.ArgumentParser(description="Serving start extra worker") parser.add_argument('--servable_directory', type=str, required=True, help="servable directory") parser.add_argument('--servable_name', type=str, required=True, help="servable name") parser.add_argument('--version_number', type=int, required=True, help="version numbers") parser.add_argument('--device_type', type=str, required=True, help="device type") parser.add_argument('--device_ids_empty', type=str, required=True, help="whether device ids are empty") parser.add_argument('--index', type=int, required=True, help="extra worker index") parser.add_argument('--enable_lite', type=str, required=True, help="enable lite") parser.add_argument('--master_address', type=str, required=True, help="master address") parser.add_argument('--dec_key_pipe_file', type=str, required=True, help="dec key pipe file") parser.add_argument('--dec_mode', type=str, required=True, help="dec mode") parser.add_argument('--listening_master', type=str, required=True, help="whether listening master") args = parser.parse_args() servable_directory = args.servable_directory servable_name = args.servable_name version_number = int(args.version_number) device_type = args.device_type # pylint: disable=simplifiable-if-expression device_ids_empty = True if args.device_ids_empty.lower() == "true" else False index = int(args.index)
master_address = args.master_address dec_key_pipe = args.dec_key_pipe_file if dec_key_pipe != "None": with open(dec_key_pipe, "rb") as fp: dec_key = fp.read() prefix = "serving_temp_dec_" if dec_key_pipe[:len(prefix)] == prefix: os.remove(dec_key_pipe) else: dec_key = None dec_mode = args.dec_mode # pylint: disable=simplifiable-if-expression listening_master = True if args.listening_master.lower() == "true" else False # pylint: disable=simplifiable-if-expression enable_lite = True if args.enable_lite.lower() == "true" else False start_extra_worker(servable_directory, servable_name, version_number, device_type, device_ids_empty, index, master_address, dec_key, dec_mode, listening_master, enable_lite) if __name__ == '__main__': parse_args_and_start() ================================================ FILE: mindspore_serving/server/start_worker.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Start worker process with single core servable""" import os import signal import argparse from mindspore_serving.server import worker from mindspore_serving.server.common import check_type from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Worker_ def start_worker(servable_directory, servable_name, version_number, device_type, device_id, master_address, dec_key, dec_mode, listening_master, enable_lite): """Start worker process with single core servable""" signal.signal(signal.SIGCHLD, signal.SIG_DFL) # for ccec compiler check_type.check_str('servable_directory', servable_directory) check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 0) check_type.check_str('device_type', device_type) check_type.check_int('device_id', device_id, 0) check_type.check_str('master_address', master_address) check_type.check_bool('listening_master', listening_master) check_type.check_bool('enable_lite', enable_lite) ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message # for servable_config.py to get device id of current worker. 
os.environ["SERVING_DEVICE_ID"] = str(device_id) worker_pid = os.getpid() unix_socket_dir = "unix_socket_files" try: os.mkdir(unix_socket_dir) except FileExistsError: pass worker_address = f"unix:{unix_socket_dir}/serving_worker_{servable_name}_device{device_id}_{worker_pid}" if len(worker_address) > 90: # limit maximum unix domain socket address length worker_address = worker_address[:40] + "___" + worker_address[-40:] try: worker.start_servable(servable_directory=servable_directory, servable_name=servable_name, version_number=version_number, device_type=device_type, device_id=device_id, master_address=master_address, worker_address=worker_address, dec_key=dec_key, dec_mode=dec_mode, enable_lite=enable_lite) except Exception as ex: Worker_.notify_failed(master_address, f"{{servable name:{servable_name}, device id:{device_id}, <{ex}>}}") raise def parse_args_and_start(): """Parse args and start distributed worker""" parser = argparse.ArgumentParser(description="Serving start extra worker") parser.add_argument('--servable_directory', type=str, required=True, help="servable directory") parser.add_argument('--servable_name', type=str, required=True, help="servable name") parser.add_argument('--version_number', type=int, required=True, help="version numbers") parser.add_argument('--device_type', type=str, required=True, help="device type") parser.add_argument('--device_id', type=str, required=True, help="device id") parser.add_argument('--master_address', type=str, required=True, help="master address") parser.add_argument('--enable_lite', type=str, required=True, help="enable lite") parser.add_argument('--dec_key_pipe_file', type=str, required=True, help="dec key pipe file") parser.add_argument('--dec_mode', type=str, required=True, help="dec mode") parser.add_argument('--listening_master', type=str, required=True, help="whether listening master") args = parser.parse_args() servable_directory = args.servable_directory servable_name = args.servable_name version_number = int(args.version_number) device_type = args.device_type device_id = int(args.device_id) master_address = args.master_address dec_key_pipe = args.dec_key_pipe_file if dec_key_pipe != "None": with open(dec_key_pipe, "rb") as fp: dec_key = fp.read() prefix = "serving_temp_dec_" if dec_key_pipe[:len(prefix)] == prefix: os.remove(dec_key_pipe) else: dec_key = None dec_mode = args.dec_mode # pylint: disable=simplifiable-if-expression listening_master = True if args.listening_master.lower() == "true" else False # pylint: disable=simplifiable-if-expression enable_lite = True if args.enable_lite.lower() == "true" else False start_worker(servable_directory, servable_name, version_number, device_type, device_id, master_address, dec_key, dec_mode, listening_master, enable_lite) if __name__ == '__main__': parse_args_and_start() ================================================ FILE: mindspore_serving/server/worker/__init__.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """MindSpore worker implementation""" from ._worker import start_servable, start_extra_servable, stop, get_newest_version_number ================================================ FILE: mindspore_serving/server/worker/_worker.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Interface for start up servable""" import os import sys from functools import wraps from mindspore_serving import log as logger from mindspore_serving.server.common import check_type, get_abs_path from mindspore_serving.server.worker import init_mindspore from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Worker_ from mindspore_serving._mindspore_serving import ServableContext_ from .task import _start_py_task _wait_and_clear_thread = None def _set_enable_lite(enable_lite): """Set whether the inference backend is MindSpore Lite""" ServableContext_.get_instance().set_enable_lite(enable_lite) def _set_device_id(device_id): """Set device id, default 0""" ServableContext_.get_instance().set_device_id(device_id) def _set_device_type(device_type): """Set device type, now can be 'None' (default), 'GPU', 'Ascend' or 'Davinci' (same as 'Ascend'), case ignored. """ if device_type is not None: check_type.check_str('device_type', device_type) ServableContext_.get_instance().set_device_type_str(device_type) else: ServableContext_.get_instance().set_device_type_str('None') # depends on MindSpore build target def get_newest_version_number(servable_directory, servable_name): """Get newest version number of servable""" max_version = 0 servable_directory = get_abs_path(servable_directory) version_root_dir = os.path.join(servable_directory, servable_name) try: files = os.listdir(version_root_dir) except FileNotFoundError: return 0 for file in files: if not os.path.isdir(os.path.join(version_root_dir, file)): continue # skip non-numeric names, "0" and names with leading zeros if not file.isdigit() or file == "0" or str(int(file)) != file: continue version = int(file) if max_version < version: max_version = version return max_version def stop(): r""" Stop the running worker. Examples: >>> import os >>> from mindspore_serving import server >>> >>> servable_dir = os.path.abspath(".") >>> config = server.ServableStartConfig(servable_dir, "lenet", device_ids=0) >>> server.start_servables(servable_configs=config) >>> server.start_grpc_server("0.0.0.0:5500") >>> ...
>>> server.stop() """ Worker_.stop_and_clear() def stop_on_except(func): """Wrapper that clears the environment and exits on Serving exception""" @wraps(func) def handle_except(*args, **kwargs): try: ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message func(*args, **kwargs) except: stop() raise return handle_except def _load_servable_config(servable_directory, servable_name): """Load servable config named servable_config.py in directory `servable_directory`/`servable_name` """ config_dir = os.path.join(servable_directory, servable_name) if not os.path.isdir(config_dir): raise RuntimeError(f"Load servable config failed, directory '{config_dir}' does not exist, " f"servable directory '{servable_directory}', servable name '{servable_name}'") config_file = os.path.join(config_dir, "servable_config.py") if not os.path.isfile(config_file): raise RuntimeError(f"Load servable config failed, file '{config_file}' does not exist, " f"servable directory '{servable_directory}', servable name '{servable_name}'") sys.path.append(servable_directory) try: __import__(servable_name + ".servable_config") except Exception as e: logger.error(f"import {servable_name}.servable_config failed, {str(e)}") raise RuntimeError(f"import {servable_name}.servable_config failed, {str(e)}") @stop_on_except def start_servable(servable_directory, servable_name, version_number, device_type, device_id, master_address, worker_address, dec_key, dec_mode, enable_lite): r""" Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master through gRPC master_address and worker_address. """ check_type.check_str('servable_directory', servable_directory) check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 0) check_type.check_int('device_id', device_id, 0) check_type.check_str('master_address', master_address) check_type.check_str('worker_address', worker_address) if dec_key is not None: check_type.check_bytes('dec_key', dec_key) else: dec_key = '' check_type.check_str('dec_mode', dec_mode) check_type.check_bool('enable_lite', enable_lite) _set_enable_lite(enable_lite) _load_servable_config(servable_directory, servable_name) model_names = Worker_.get_declared_model_names() if model_names: init_mindspore.init_mindspore_cxx_env(enable_lite) newest_version_number = get_newest_version_number(servable_directory, servable_name) if not newest_version_number: raise RuntimeError( f"There is no valid version directory of models while there are models declared in servable_config.py, " f"servable directory: {servable_directory}, servable name: {servable_name}") if version_number == 0: version_number = 1 _set_device_type(device_type) _set_device_id(device_id) Worker_.start_servable(servable_directory, servable_name, version_number, master_address, worker_address, dec_key, dec_mode) _start_py_task() @stop_on_except def start_extra_servable(servable_directory, servable_name, version_number, device_type, device_ids_empty, dec_key, dec_mode, master_address, worker_address, enable_lite): r""" Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master through gRPC master_address and worker_address.
""" check_type.check_str('servable_directory', servable_directory) check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 0) check_type.check_str('device_type', device_type) check_type.check_bool('device_ids_empty', device_ids_empty) check_type.check_str('master_address', master_address) check_type.check_str('worker_address', worker_address) if dec_key is not None: check_type.check_bytes('dec_key', dec_key) else: dec_key = '' check_type.check_str('dec_mode', dec_mode) check_type.check_bool('enable_lite', enable_lite) _set_enable_lite(enable_lite) _load_servable_config(servable_directory, servable_name) model_names = Worker_.get_declared_model_names() if model_names: init_mindspore.init_mindspore_cxx_env(enable_lite) newest_version_number = get_newest_version_number(servable_directory, servable_name) if not newest_version_number: raise RuntimeError( f"There is no valid version directory of models while there are models declared in servable_config.py, " f"servable directory: {servable_directory}, servable name: {servable_name}") if version_number == 0: version_number = 1 _set_device_type(device_type) Worker_.start_extra_servable(servable_directory, servable_name, version_number, device_ids_empty, dec_key, dec_mode, master_address, worker_address) _start_py_task() ================================================ FILE: mindspore_serving/server/worker/check_version.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ """version and config check""" import os import sys import subprocess from pathlib import Path from packaging import version from mindspore_serving import log as logger class AscendEnvChecker: """ascend environment check""" def __init__(self): atlas_nnae_version = "/usr/local/Ascend/nnae/latest/compiler/version.info" atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/compiler/version.info" hisi_fwk_version = "/usr/local/Ascend/latest/compiler/version.info" if os.path.exists(atlas_nnae_version): # atlas default path self.fwk_path = "/usr/local/Ascend/nnae/latest" self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/built-in/op_impl/ai_core/tbe" self.tbe_path = self.fwk_path + "/lib64" self.cce_path = self.fwk_path + "/compiler/ccec_compiler/bin" self.fwk_version = atlas_nnae_version self.op_path = "/usr/local/Ascend/nnae/latest/opp" self.aicpu_path = "/usr/local/Ascend/nnae/latest" elif os.path.exists(atlas_toolkit_version): # atlas default path self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest" self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/built-in/op_impl/ai_core/tbe" self.tbe_path = self.fwk_path + "/lib64" self.cce_path = self.fwk_path + "/compiler/ccec_compiler/bin" self.fwk_version = atlas_toolkit_version self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" elif os.path.exists(hisi_fwk_version): # hisi default path self.fwk_path = "/usr/local/Ascend/latest" self.op_impl_path = "/usr/local/Ascend/latest/opp/built-in/op_impl/ai_core/tbe" self.tbe_path = self.fwk_path + "/lib64" self.cce_path = self.fwk_path + "/compiler/ccec_compiler/bin" self.fwk_version = hisi_fwk_version self.op_path = "/usr/local/Ascend/latest/opp" self.aicpu_path = "/usr/local/Ascend/latest" else: # custom or unknown environment self.fwk_path = "" self.op_impl_path = "" self.tbe_path = "" self.cce_path = "" self.fwk_version = "" self.op_path = "" self.aicpu_path = "" # env self.path = os.getenv("PATH") self.python_path = os.getenv("PYTHONPATH") self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH") self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH") # check content self.path_check = "/compiler/ccec_compiler/bin" self.python_path_check = "opp/built-in/op_impl/ai_core/tbe" self.ld_lib_path_check_fwk = "/lib64" self.ld_lib_path_check_addons = "/add-ons" self.ascend_opp_path_check = "/op" self.v = "" def check_env(self, e): """check system env""" self._check_env() raise e def set_env(self): """set env: LD_LIBRARY_PATH, PATH, ASCEND_OPP_PATH""" if not self.tbe_path: self._check_env() return if Path(self.tbe_path).is_dir(): if os.getenv('LD_LIBRARY_PATH'): os.environ['LD_LIBRARY_PATH'] = self.tbe_path + ":" + os.environ['LD_LIBRARY_PATH'] else: os.environ['LD_LIBRARY_PATH'] = self.tbe_path else: logger.warning(f"No such directory: {self.tbe_path}, Please check if Ascend 910 AI software package is " f"installed correctly.") if Path(self.op_impl_path).is_dir(): # python path for sub process if os.getenv('PYTHONPATH'): os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH'] else: os.environ['PYTHONPATH'] = self.op_impl_path # sys path for this process sys.path.append(self.op_impl_path) os.environ['TBE_IMPL_PATH'] = self.op_impl_path else: logger.warning( f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data " "Center 
Solution) is installed correctly.") return if Path(self.cce_path).is_dir(): os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH'] else: logger.warning( f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center " "Solution) is installed correctly.") return if self.op_path is None: pass elif Path(self.op_path).is_dir(): os.environ['ASCEND_OPP_PATH'] = self.op_path else: logger.warning( f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center " "Solution) is installed correctly.") return if self.aicpu_path is None: pass elif Path(self.aicpu_path).is_dir(): os.environ['ASCEND_AICPU_PATH'] = self.aicpu_path else: logger.warning( f"No such directory: {self.aicpu_path}, Please check if Ascend AI software package (Ascend Data Center" " Solution) is installed correctly.") return def try_set_env_lib(self): """try set env but with no warning: LD_LIBRARY_PATH""" if Path(self.tbe_path).is_dir(): if os.getenv('LD_LIBRARY_PATH'): os.environ['LD_LIBRARY_PATH'] = self.tbe_path + ":" + os.environ['LD_LIBRARY_PATH'] else: os.environ['LD_LIBRARY_PATH'] = self.tbe_path def _check_env(self): """ascend dependence path check""" if self.path is None or self.path_check not in self.path: logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env " "PATH, you can reference to the installation guidelines https://www.mindspore.cn/install") if self.python_path is None or self.python_path_check not in self.python_path: logger.warning( "Can not find tbe op implement(need by mindspore-ascend), please check if you have set env " "PYTHONPATH, you can reference to the installation guidelines " "https://www.mindspore.cn/install") if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and self.ld_lib_path_check_addons in self.ld_lib_path): logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env " "LD_LIBRARY_PATH, you can reference to the installation guidelines " "https://www.mindspore.cn/install") if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path: logger.warning( "Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, " "you can reference to the installation guidelines https://www.mindspore.cn/install") class GPUEnvChecker(): """GPU environment check.""" def __init__(self): self.version = ["10.1"] # env self.path = os.getenv("PATH") self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") # check self.v = "0" self.cuda_lib_path = self._get_lib_path("libcu") self.cuda_bin_path = self._get_bin_path("cuda") def _get_bin_path(self, bin_name): """Get bin path by bin name.""" if bin_name == "cuda": return self._get_cuda_bin_path() return [] def _get_cuda_bin_path(self): """Get cuda bin path by lib path.""" path_list = [] for path in self.cuda_lib_path: path = os.path.abspath(path.strip() + "/bin/") if Path(path).is_dir(): path_list.append(path) return list(set(path_list)) def _get_nvcc_version(self, is_set_env): """Get cuda version by nvcc command.""" nvcc_result = subprocess.run(["nvcc --version | grep release"], timeout=3, text=True, capture_output=True, check=False, shell=True) if nvcc_result.returncode: if not is_set_env: for path in self.cuda_bin_path: if Path(path + "/nvcc").is_file(): os.environ['PATH'] = path + ":" + os.environ['PATH'] return self._get_nvcc_version(True) return "" result = nvcc_result.stdout for line in result.split('\n'): if 
line: return line.strip().split("release")[1].split(",")[0].strip() return "" def check_env(self): """Check cuda version.""" version_match = False for path in self.cuda_lib_path: version_file = path + "/version.txt" if not Path(version_file).is_file(): continue if self._check_version(version_file): version_match = True break if not version_match: if self.v == "0": logger.warning("Cuda version file version.txt is not found, please confirm that the correct " "cuda version has been installed, you can refer to the " "installation guidelines: https://www.mindspore.cn/install") else: logger.warning(f"MindSpore version and cuda version {self.v} do not match, " "please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install") nvcc_version = self._get_nvcc_version(False) if nvcc_version and (nvcc_version not in self.version): logger.warning(f"MindSpore version and nvcc(cuda bin) version {nvcc_version} " "do not match, please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install") def _check_version(self, version_file): """Check cuda version by version.txt.""" v = self._read_version(version_file) v = version.parse(v) v_str = str(v.major) + "." + str(v.minor) if v_str not in self.version: return False return True def _get_lib_path(self, lib_name): """Get gpu lib path by ldd command.""" path_list = [] current_path = os.path.split(os.path.realpath(__file__))[0] mindspore_path = os.path.dirname(os.path.dirname(current_path)) + "/mindspore" ldd_result = subprocess.run(["ldd " + mindspore_path + "/_c_expression*.so* | grep " + lib_name], timeout=3, text=True, capture_output=True, check=False, shell=True) if ldd_result.returncode: logger.warning(f"{lib_name} so (needed by mindspore-gpu) is not found, please confirm that " f"_c_expression.so depends on {lib_name}, " f"and _c_expression.so is in directory: {mindspore_path}") return path_list result = ldd_result.stdout for i in result.split('\n'): path = i.partition("=>")[2] if path.lower().find("not found") > 0: logger.warning(f"Cuda {self.version} version (needed by mindspore-gpu) is not found, please confirm " "that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the " "installation guidelines: https://www.mindspore.cn/install") continue path = path.partition(lib_name)[0] if path: path_list.append(os.path.abspath(path.strip() + "../")) return list(set(path_list)) def _read_version(self, file_path): """Get gpu version info in version.txt.""" with open(file_path, 'r') as f: all_info = f.readlines() for line in all_info: if line.startswith("CUDA Version"): self.v = line.strip().split("CUDA Version")[1] return self.v return self.v def check_version_and_env_config(device_type): """check version and env config""" if device_type == "Ascend": env_checker = AscendEnvChecker() try: env_checker.set_env() except ImportError as e: env_checker.check_env(e) elif device_type == "Gpu": env_checker = GPUEnvChecker() env_checker.check_env() elif device_type == "Cpu": pass def check_version_and_try_set_env_lib(): """check version and try set env LD_LIBRARY_PATH""" env_checker = AscendEnvChecker() env_checker.try_set_env_lib() ================================================ FILE: mindspore_serving/server/worker/distributed/__init__.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
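# ---------------------------------------------------------------------------
# Illustrative aside, not part of this file: a minimal, self-contained sketch
# of the major.minor matching that GPUEnvChecker._check_version and
# _read_version above perform on a "CUDA Version x.y.z" line. The helper name
# and the supported-version tuple are hypothetical.
# ---------------------------------------------------------------------------
from packaging import version as _version

def _cuda_version_matches(version_line, supported=("10.1",)):
    """Extract 'major.minor' from a 'CUDA Version x.y.z' line and compare."""
    v = _version.parse(version_line.split("CUDA Version")[1].strip())
    return f"{v.major}.{v.minor}" in supported

assert _cuda_version_matches("CUDA Version 10.1.243")
# ---------------------------------------------------------------------------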
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """MindSpore Serving Distributed Worker.""" from .agent_startup import startup_agents from .register import declare_servable from .distributed_worker import start_servable ================================================ FILE: mindspore_serving/server/worker/distributed/agent_startup.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Serving, distributed worker agent startup""" import os import time import sys import traceback import signal from multiprocessing import Process, Pipe import threading import psutil from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ from mindspore_serving._mindspore_serving import DistributedServableConfig_, OneRankConfig_ from mindspore_serving import log as logger from mindspore_serving.server.common import check_type from mindspore_serving.server.worker.distributed import worker_agent def _get_local_ip(rank_list, port): """Get the local ip from the rank table config""" import socket ip_list = set() for item in rank_list: ip_list.add(item.ip) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) for ip in ip_list: try: s.bind((ip, port)) logger.info(f"Get local machine ip success, ip {ip}") return ip # pylint: disable=bare-except except: pass raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}") def _check_local_ip(agent_ip, port): """Check the local ip""" import socket with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) for i in range(8): try: s.bind((agent_ip, port + i)) logger.info(f"Check local machine ip success, ip {agent_ip}") return True # pylint: disable=bare-except except: pass return False def _check_model_files(num, files, model_files, group_config_files): """Check the number of model files or group config files""" if isinstance(files, tuple): for item in files: if isinstance(item, list): if num == -1: num = len(item) else: if num != len(item): raise RuntimeError(f"please check the number of model files and group config files, " f"model files: {model_files}, group config files: {group_config_files}") else: if num not in (-1, 1): raise RuntimeError(f"please check the number of model files and group config files, " f"model files: {model_files}, group 
config files: {group_config_files}") num = 1 return num def _check_model_num(model_files, group_config_files): """Check the number of model files or group config files""" num = _check_model_files(-1, model_files, model_files, group_config_files) if group_config_files is not None: num = _check_model_files(-1, group_config_files, model_files, group_config_files) if num != 1: raise RuntimeError(f"please check the number of group config files, currently only support one at most") def _update_model_files_path(model_files, group_config_files): """Check and return model files or group config files""" script_dir = os.path.dirname(os.path.realpath(sys.argv[0])) logger.info(f"input model files: {model_files}") logger.info(f"input group config files: {group_config_files}") model_files_temp = [] for item in model_files: if isinstance(item, list): inner_files = [] for inner in item: file_name = os.path.realpath(os.path.join(script_dir, inner)) if not os.access(file_name, os.R_OK): raise RuntimeError(f"Cannot access model file '{file_name}'") inner_files.append(file_name) model_files_temp.append(inner_files) else: file_name = os.path.realpath(os.path.join(script_dir, item)) if not os.access(file_name, os.R_OK): raise RuntimeError(f"Cannot access model file '{file_name}'") model_files_temp.append(file_name) if group_config_files is not None: group_files_temp = [] for item in group_config_files: if isinstance(item, list): inner_files = [] for inner in item: file_name = os.path.realpath(os.path.join(script_dir, inner)) if not os.access(file_name, os.R_OK): raise RuntimeError(f"Cannot access group config file '{file_name}'") inner_files.append(file_name) group_files_temp.append(inner_files) else: file_name = os.path.realpath(os.path.join(script_dir, item)) if not os.access(file_name, os.R_OK): raise RuntimeError(f"Cannot access group config file '{file_name}'") group_files_temp.append(file_name) else: group_files_temp = None logger.info(f"absolute model files: {model_files_temp}") logger.info(f"absolute group config files: {group_files_temp}") return model_files_temp, group_files_temp def _make_json_table_file(distributed_config): """Make rank table json file""" rank_size = len(distributed_config.rank_list) runtime_dir = os.path.abspath(".") time_stamp = str(time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))) rank_table_dir = os.path.join(runtime_dir, "temp_rank_table") try: os.mkdir(rank_table_dir) except FileExistsError: pass rank_table_file_name = os.path.join(rank_table_dir, f"hccl_rank_table_{time_stamp}_{rank_size}p.json") with open(rank_table_file_name, "w") as fp: fp.write(distributed_config.rank_table_content) return rank_table_file_name signal_success = "Success" signal_exit = "Exit" def _recv_parent(parent_process, index, recv_pipe, handle_stop_signal=True): """Receive message from Start up process. Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit. Return True on receiving signal_success. 
""" try: while True: while not recv_pipe.poll(0.1): if handle_stop_signal and ExitSignalHandle_.has_stopped(): logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker") return False if not parent_process.is_running(): # 3s logger.warning(f"Child {index}: Exit on failure of exit of parent process") return False parent_signal = recv_pipe.recv() break if parent_signal == signal_success: logger.info(f"Child {index}: Receive success") return True if parent_signal == signal_exit: logger.warning(f"Child {index}: Exit on receiving exit message") # pylint: disable=broad-except except Exception as e: logger.warning(f"Child {index}: Exit on exception: {e}") return False def _agent_process(send_pipe, recv_pipe, index, start_config, dec_key, dec_mode): """Agent process""" parent_process = psutil.Process(os.getppid()) try: # listening success or failed message from parent process worker_agent.start_worker_agent(start_config=start_config, dec_key=dec_key, dec_mode=dec_mode) send_pipe.send((index, signal_success)) success_msg = _recv_parent(parent_process, index, recv_pipe) if not success_msg: worker_agent.stop() send_pipe.close() recv_pipe.close() while not ExitSignalHandle_.has_stopped(): if not parent_process.is_running(): logger.warning(f"Child {index}, detect parent pid={parent_process.pid} has exited, child begin to exit") worker_agent.stop() return time.sleep(0.1) # pylint: disable=broad-except except Exception as e: traceback.print_exc() logger.error(f"Child {index}: Catch exception and notify exit of others") exception = RuntimeError(f"Child {index} exception happen: {e}") send_pipe.send((index, exception)) _recv_parent(parent_process, index, recv_pipe, False) logger.error(f"Child {index}: end send message to parent") def _send_pipe_msg(send_pipe, msg): """Send pipe message""" try: send_pipe.send(msg) # pylint: disable=broad-except except Exception as e: logger.warning(f"Send pipe message exception happen: {e}") def _send_exit_signal_to_children(subprocess_list): """Send exit signal to all child processes, and terminate all child processes when they are still alive in some seconds later""" def wait_exit(wait_seconds, msg): for i in range(wait_seconds): all_exit = True for process in subprocess_list: if process.is_alive(): logger.warning(f"There are still child processes that have not exited and {msg} in " f"{wait_seconds - i} seconds.") time.sleep(1) all_exit = False break if all_exit: logger.info(f"All Child process exited") return True return False if wait_exit(3, "SIGINT will be sent"): return # Send signal SIGINT for index, process in enumerate(subprocess_list): if process.is_alive(): logger.warning(f"Send signal SIGINT to {index}") try: child_process = psutil.Process(process.pid) children_of_child = child_process.children(recursive=True) for item in children_of_child: os.kill(item.pid, signal.SIGINT) # pylint: disable=broad-except except Exception as e: logger.warning(f"Get exception when send signal SIGINT to children of child {index}, exception: {e}") os.kill(process.pid, signal.SIGINT) if wait_exit(10, "will be forcibly killed"): return for index, process in enumerate(subprocess_list): if process.is_alive(): logger.warning(f"Kill Child process {index}") try: child_process = psutil.Process(process.pid) children_of_child = child_process.children(recursive=True) for item in children_of_child: os.kill(item.pid, signal.SIGKILL) # pylint: disable=broad-except except Exception as e: logger.warning(f"Get exception when send signal SIGKILL to children of child {index}, 
exception: {e}") os.kill(process.pid, signal.SIGKILL) def _send_exit_msg_to_children(send_pipe_list, subprocess_list): """Send exit msg to all child processes, and terminate all child processes when they are still alive in some seconds later. """ index = 0 for send_pipe, process in zip(send_pipe_list, subprocess_list): if process.is_alive(): logger.warning(f"Send exit message to Child {index}") _send_pipe_msg(send_pipe, signal_exit) logger.warning(f"End send exit message to Child {index}") else: logger.warning(f"Child {index} is not alive") index += 1 _send_exit_signal_to_children(subprocess_list) def _listening_agents_when_startup(p_recv_pipe, send_pipe_list, subprocess_list): """Listening child process""" count = len(send_pipe_list) for _ in range(count): while True: if p_recv_pipe.poll(0.1): break if ExitSignalHandle_.has_stopped(): logger.warning("Fail to start agents because of Ctrl+C") _send_exit_msg_to_children(send_pipe_list, subprocess_list) raise RuntimeError("Fail to start agents because of Ctrl+C") for send_pipe, process in zip(send_pipe_list, subprocess_list): if process.is_alive(): continue logger.warning("Fail to start agents because of death of one agent") _send_exit_msg_to_children(send_pipe_list, subprocess_list) raise RuntimeError("Fail to start agents because of death of one agent") index, msg = p_recv_pipe.recv() logger.info(f"Receive msg from Child {index}: {msg}") if isinstance(msg, Exception): logger.warning("Fail to start agents because of exception raise by one agent") _send_exit_msg_to_children(send_pipe_list, subprocess_list) raise msg for send_pipe in send_pipe_list: _send_pipe_msg(send_pipe, signal_success) def _listening_agents_after_startup(subprocess_list, distributed_address, agent_ip): """Listening agent status after success start up of agents""" def wait_child_exit(): while not ExitSignalHandle_.has_stopped(): for index, process in enumerate(subprocess_list): if not process.is_alive(): logger.warning(f"Child {index}, pid={process.pid} has exited") return time.sleep(0.1) def listening_thread_fun(): wait_child_exit() WorkerAgent_.startup_notify_exit(distributed_address, agent_ip) _send_exit_signal_to_children(subprocess_list) thread = threading.Thread(target=listening_thread_fun) thread.start() def _startup_agents(common_meta, distributed_address, agent_ip, agent_start_port, device_id_list, rank_id_list, model_files, group_config_files, rank_table_file, dec_key, dec_mode): """Start up all agents in one machine""" servable_name = common_meta.model_key send_pipe_list = [] subprocess_list = [] c_send_pipe, p_recv_pipe = Pipe() group_file = "" agents_count = len(device_id_list) for index in range(agents_count): device_id, rank_id, model_file = device_id_list[index], rank_id_list[index], model_files[index] if group_config_files is not None: group_file = group_config_files[index] p_send_pipe, c_recv_pipe = Pipe() send_pipe_list.append(p_send_pipe) agent_port = agent_start_port + index start_config = AgentStartUpConfig_() start_config.rank_id = rank_id start_config.device_id = device_id start_config.model_file_names = model_file if group_config_files is not None: start_config.group_file_names = group_file start_config.rank_table_json_file_name = rank_table_file start_config.agent_address = agent_ip + ":" + str(agent_port) start_config.distributed_address = distributed_address start_config.common_meta = common_meta process = Process(target=_agent_process, args=(c_send_pipe, c_recv_pipe, index, start_config, dec_key, dec_mode), 
name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}") process.start() subprocess_list.append(process) msg = f"distributed worker_address: {distributed_address}, agent_ip: {agent_ip}, " \ f"agent_start_port: {agent_start_port}, device ids: {device_id_list}, rank ids: {rank_id_list}, " \ f"rank table file: {rank_table_file}, model files: {model_files}, group config files: {group_config_files}" try: _listening_agents_when_startup(p_recv_pipe, send_pipe_list, subprocess_list) # pylint: disable=broad-except except Exception as e: WorkerAgent_.notify_failed(distributed_address) logger.error(f"Failed to start agents, {msg}") print(f"Failed to start agents, {msg}") raise e logger.info(f"Success to start agents, {msg}") print(f"Success to start agents, {msg}") _listening_agents_after_startup(subprocess_list, distributed_address, agent_ip) class DistributedServableConfig: """Python DistributedServableConfig""" def __init__(self): self.rank_table_content = "" self.rank_list = None self.common_meta = None self.distributed_meta = None def set(self, config): """Set from C++ DistributedServableConfig_ obj""" self.rank_table_content = config.rank_table_content self.rank_list = [] for item in config.rank_list: new_item = {"device_id": item.device_id, "ip": item.ip} self.rank_list.append(new_item) self.common_meta = {"model_key": config.common_meta.model_key, "with_batch_dim": config.common_meta.with_batch_dim, "without_batch_dim_inputs": config.common_meta.without_batch_dim_inputs, "inputs_count": config.common_meta.inputs_count, "outputs_count": config.common_meta.outputs_count} self.distributed_meta = {"rank_size": config.distributed_meta.rank_size, "stage_size": config.distributed_meta.stage_size} def get(self): """Get as C++ DistributedServableConfig_ obj""" config = DistributedServableConfig_() config.rank_table_content = self.rank_table_content rank_list = [] for item in self.rank_list: new_item = OneRankConfig_() new_item.device_id = item["device_id"] new_item.ip = item["ip"] rank_list.append(new_item) config.rank_list = rank_list config.common_meta.model_key = self.common_meta["model_key"] config.common_meta.with_batch_dim = self.common_meta["with_batch_dim"] config.common_meta.without_batch_dim_inputs = self.common_meta["without_batch_dim_inputs"] config.common_meta.inputs_count = self.common_meta["inputs_count"] config.common_meta.outputs_count = self.common_meta["outputs_count"] config.distributed_meta.rank_size = self.distributed_meta["rank_size"] config.distributed_meta.stage_size = self.distributed_meta["stage_size"] return config def _get_worker_distributed_config(distributed_address): """Get worker distributed config from worker through sub process""" c_send_pipe, p_recv_pipe = Pipe() def process_fun(c_send_pipe): try: distributed_config = WorkerAgent_.get_agents_config_from_worker(distributed_address) config = DistributedServableConfig() config.set(distributed_config) c_send_pipe.send(config) # pylint: disable=broad-except except Exception as e: c_send_pipe.send(e) process = Process(target=process_fun, args=(c_send_pipe,), name=f"worker_agent_get_agents_config_from_worker") process.start() process.join() assert not process.is_alive() if p_recv_pipe.poll(0.1): config = p_recv_pipe.recv() if isinstance(config, Exception): raise config distributed_config = config.get() return distributed_config raise RuntimeError(f"Failed to get agents config from worker") def startup_agents(distributed_address, model_files, group_config_files=None, agent_start_port=7000, agent_ip=None, 
rank_start=None, dec_key=None, dec_mode='AES-GCM'): r""" Start all required worker agents on the current machine. These worker agent processes are responsible for inference tasks on the local machine. For details, please refer to `MindSpore Serving-based Distributed Inference Service Deployment `_. Args: distributed_address (str): The address of the distributed worker that the agents link to. model_files (Union[list[str], tuple[str]]): All model files needed on the current machine, as absolute paths or paths relative to this startup python script. group_config_files (Union[list[str], tuple[str]], optional): All group config files needed on the current machine, as absolute paths or paths relative to this startup python script; ``None`` means there are no group configuration files. Default: ``None``. agent_start_port (int, optional): The starting port of the agents that link to the worker. Default: ``7000``. agent_ip (str, optional): The local agent ip; if it's ``None``, the agent ip will be obtained from the rank table file. Parameters `agent_ip` and `rank_start` must either both have values or both be ``None``. Default: ``None``. rank_start (int, optional): The starting rank id of this machine; if it's ``None``, the rank id will be obtained from the rank table file. Parameters `agent_ip` and `rank_start` must either both have values or both be ``None``. Default: ``None``. dec_key (bytes, optional): Byte type key used for decryption. The valid length is 16, 24, or 32. Default: ``None``. dec_mode (str, optional): Specifies the decryption mode; takes effect when `dec_key` is set. Options: ``'AES-GCM'`` or ``'AES-CBC'``. Default: ``'AES-GCM'``. Raises: RuntimeError: Failed to start agents. Examples: >>> import os >>> from mindspore_serving.server import distributed >>> model_files = [] >>> for i in range(8): ... model_files.append(f"models/device{i}/matmul.mindir") >>> distributed.startup_agents(distributed_address="127.0.0.1:6200", model_files=model_files) """ check_type.check_str("distributed_address", distributed_address) check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7) model_files = check_type.check_and_as_tuple_with_str_list("model_files", model_files) if group_config_files is not None: group_config_files = check_type.check_and_as_tuple_with_str_list("group_config_files", group_config_files) # check dec_key and dec_mode if dec_key is not None: if not isinstance(dec_key, bytes): raise RuntimeError(f"Parameter 'dec_key' should be bytes, but actually {type(dec_key)}") if not dec_key: raise RuntimeError(f"Parameter 'dec_key' should not be empty bytes") if len(dec_key) not in (16, 24, 32): raise RuntimeError(f"Parameter 'dec_key' length {len(dec_key)} expected to be 16, 24 or 32") check_type.check_str("dec_mode", dec_mode) if dec_mode not in ('AES-GCM', 'AES-CBC'): raise RuntimeError(f"Parameter 'dec_mode' expected to be 'AES-GCM' or 'AES-CBC'") ExitSignalHandle_.start() distributed_config = _get_worker_distributed_config(distributed_address) # get machine ip rank_list = distributed_config.rank_list local_device_id_list = [] local_rank_id_list = [] if agent_ip is None: if rank_start is not None: raise RuntimeError("Parameter 'agent_ip' and parameter 'rank_start' must have values at the same time, " "or both None at the same time.") local_ip = _get_local_ip(rank_list, agent_start_port) # get all device_id and rank_id for rank_id, item in enumerate(rank_list): if item.ip == local_ip: local_device_id_list.append(item.device_id)
local_rank_id_list.append(rank_id) else: if rank_start is None: raise RuntimeError("Parameter 'agent_ip' and parameter 'rank_start' must have values at the same time, " "or both None at the same time.") check_type.check_str("agent_ip", agent_ip) check_type.check_int("rank_start", rank_start, 0) if rank_start >= len(rank_list): raise RuntimeError(f"Parameter 'rank_start' cannot equal or larger than rank size {len(rank_list)}.") if not _check_local_ip(agent_ip, agent_start_port): raise RuntimeError(f"Check ip 'agent_ip' valid failed, agent_ip: {agent_ip}") local_ip = agent_ip rank_table_ip = rank_list[rank_start].ip for rank_id, item in enumerate(rank_list): if item.ip == rank_table_ip: local_device_id_list.append(item.device_id) local_rank_id_list.append(rank_id) # handle model files and group config files if len(local_device_id_list) != len(model_files): raise RuntimeError(f"Card count {local_device_id_list} described rank table does not equal to model files size " f"{len(model_files)}, model files: {model_files}") if group_config_files is not None and len(model_files) != len(group_config_files): raise RuntimeError(f"Model files count {len(model_files)} does not equal to group config files " f"count {len(group_config_files)} when group_config_files is not None, " f"model files: {model_files}, group config files: {group_config_files}") _check_model_num(model_files, group_config_files) model_files, group_config_files = _update_model_files_path(model_files, group_config_files) # make json table file and export env rank_table_file = _make_json_table_file(distributed_config) _startup_agents(distributed_config.common_meta, distributed_address, local_ip, agent_start_port, local_device_id_list, local_rank_id_list, model_files, group_config_files, rank_table_file, dec_key, dec_mode) ================================================ FILE: mindspore_serving/server/worker/distributed/distributed_worker.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Serving, distributed worker startup""" from mindspore_serving._mindspore_serving import Worker_ from mindspore_serving.server.common import check_type from mindspore_serving.server.worker._worker import _start_py_task from mindspore_serving.server.worker._worker import stop_on_except, _load_servable_config @stop_on_except def start_servable(servable_directory, servable_name, rank_table_json_file, version_number, distributed_address, wait_agents_time_in_seconds, master_address, worker_address): r""" Start up the servable named 'servable_name' defined in 'servable_directory'. 
""" check_type.check_str('servable_directory', servable_directory) check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 1) check_type.check_str('rank_table_json_file', rank_table_json_file) check_type.check_str('distributed_address', distributed_address) _load_servable_config(servable_directory, servable_name) Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number, distributed_address, master_address, worker_address, wait_agents_time_in_seconds) _start_py_task() ================================================ FILE: mindspore_serving/server/worker/distributed/register.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Serving, distributed worker register""" from mindspore_serving import log as logger from mindspore_serving.server.common import check_type from mindspore_serving.server.register.utils import get_servable_dir from mindspore_serving.server.register.model import append_declared_model from mindspore_serving._mindspore_serving import ModelMeta_, ServableRegister_ def declare_servable(rank_size, stage_size, with_batch_dim=True, without_batch_dim_inputs=None, enable_pipeline_infer=False): """declare distributed servable in servable_config.py. For details, please refer to `MindSpore Serving-based Distributed Inference Service Deployment `_. Args: rank_size (int): The rank size of the distributed model. stage_size (int): The stage size of the distributed model. with_batch_dim (bool, optional): Whether the first shape dim of the inputs and outputs of model is batch. Default: ``True``. without_batch_dim_inputs (Union[int, tuple[int], list[int]], optional): Index of inputs that without batch dim when `with_batch_dim` is ``True``. Default: ``None``. enable_pipeline_infer (bool, optional): Whether to enable pipeline parallel inference. Pipeline parallelism can effectively improve inference performance. For details, see `Pipeline Parallelism `_. Default: ``False``. Return: Model, identification of this model, can be used for `Model.call` or as the inputs of `add_stage`. Raises: RuntimeError: The type or value of the parameters are invalid. 
Examples: >>> from mindspore_serving.server import distributed >>> model = distributed.declare_servable(rank_size=8, stage_size=1) """ check_type.check_bool('with_batch_dim', with_batch_dim) check_type.check_bool('enable_pipeline_infer', enable_pipeline_infer) meta = ModelMeta_() meta.common_meta.servable_name = get_servable_dir() meta.common_meta.model_key = get_servable_dir() # used to identify model meta.common_meta.with_batch_dim = with_batch_dim if without_batch_dim_inputs: without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', without_batch_dim_inputs, 0) meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs # init distributed servable meta info check_type.check_int("rank_size", rank_size, 1) check_type.check_int("stage_size", stage_size, 1) meta.distributed_meta.rank_size = rank_size meta.distributed_meta.stage_size = stage_size meta.distributed_meta.enable_pipeline_infer = enable_pipeline_infer ServableRegister_.declare_distributed_model(meta) logger.info(f"Declare distributed servable, servable name: {meta.common_meta.model_key} " f", rank_size: {rank_size} , stage_size: {stage_size}, with_batch_dim: {with_batch_dim} " f", without_batch_dim_inputs: {without_batch_dim_inputs} " f", enable_pipeline_infer: {enable_pipeline_infer}") return append_declared_model(meta.common_meta.model_key) ================================================ FILE: mindspore_serving/server/worker/distributed/worker_agent.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
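# ---------------------------------------------------------------------------
# Illustrative aside, not part of this file: a minimal servable_config.py
# sketch showing how the declare_servable API above is typically combined
# with register_method/add_stage (compare example/matmul_distributed). The
# method name, the preprocess stage and the dtype cast are hypothetical.
# ---------------------------------------------------------------------------
import numpy as np
from mindspore_serving.server import distributed
from mindspore_serving.server import register

model = distributed.declare_servable(rank_size=8, stage_size=1, with_batch_dim=False)

def add_trans_datatype(x):
    """User-defined preprocess stage: cast the input to float32."""
    return x.astype(np.float32)

@register.register_method(output_names=["y"])
def predict(x):
    x = register.add_stage(add_trans_datatype, x, outputs_count=1)
    y = register.add_stage(model, x, outputs_count=1)
    return y
# ---------------------------------------------------------------------------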
# ============================================================================ """Serving, distributed worker agent""" import os import threading from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_, ExitSignalHandle_ from mindspore_serving import log as logger from mindspore_serving.server.worker import init_mindspore def start_worker_agent(start_config, dec_key, dec_mode): """Start up one worker agent on one device id, invoke by agent_startup.startup_worker_agents """ if not isinstance(start_config, AgentStartUpConfig_): raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_") logger.info(f"rank_id={start_config.rank_id}, device_id={start_config.device_id}, " f"model_file='{start_config.model_file_names}', group_file='{start_config.group_file_names}', " f"rank_table_file='{start_config.rank_table_json_file_name}'," f"agent_address='{start_config.agent_address}', " f"distributed_address='{start_config.distributed_address}'" f"with_batch_dim={start_config.common_meta.with_batch_dim}, " f"without_batch_dim_inputs={start_config.common_meta.without_batch_dim_inputs}") ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message init_mindspore.init_mindspore_cxx_env(False) os.environ["RANK_ID"] = str(start_config.rank_id) os.environ["DEVICE_ID"] = str(start_config.device_id) os.environ["MS_ENABLE_HCCL"] = "1" if start_config.group_file_names: os.environ["PARA_GROUP_FILE"] = ';'.join(start_config.group_file_names) os.environ["RANK_TABLE_FILE"] = start_config.rank_table_json_file_name for item in ("RANK_ID", "DEVICE_ID", "MS_ENABLE_HCCL", "PARA_GROUP_FILE", "RANK_TABLE_FILE", "LD_LIBRARY_PATH", "PYTHONPATH"): logger.info(f"Env {item}: {os.getenv(item, None)}") if dec_key is None: dec_key = '' WorkerAgent_.start_agent(start_config, dec_key, dec_mode) start_wait_and_clear() _wait_and_clear_thread = None def start_wait_and_clear(): """Waiting for Ctrl+C, and clear up environment""" def thread_func(): logger.info("Serving worker Agent: wait for Ctrl+C to exit ------------------------------------") print("Serving worker Agent: wait for Ctrl+C to exit ------------------------------------") WorkerAgent_.wait_and_clear() logger.info("Serving worker Agent: exited ------------------------------------") print("Serving worker Agent: exited ------------------------------------") global _wait_and_clear_thread if not _wait_and_clear_thread: _wait_and_clear_thread = threading.Thread(target=thread_func) _wait_and_clear_thread.start() def stop(): r""" Stop the running of agent. """ WorkerAgent_.stop_and_clear() ================================================ FILE: mindspore_serving/server/worker/init_mindspore.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
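# ---------------------------------------------------------------------------
# Illustrative aside, not part of this file: the orphan-protection pattern
# used by _agent_process in agent_startup.py above, reduced to a minimal
# sketch. A child polls its parent via psutil and runs a cleanup callback
# once the parent is gone; the helper name and poll interval are hypothetical.
# ---------------------------------------------------------------------------
import os
import time
import psutil

def watch_parent(on_parent_exit, poll_interval=0.1):
    """Block until the parent process exits, then run the cleanup callback."""
    parent = psutil.Process(os.getppid())
    while parent.is_running():
        time.sleep(poll_interval)
    on_parent_exit()
# ---------------------------------------------------------------------------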
# ============================================================================ """Init MindSpore Cxx""" import os import importlib.util from mindspore_serving import log as logger from mindspore_serving._mindspore_serving import Worker_ from .check_version import check_version_and_env_config, check_version_and_try_set_env_lib _flag_set_mindspore_cxx_env = False def get_mindspore_whl_path(): """Get MindSpore whl install path""" model_spec = importlib.util.find_spec("mindspore") if not model_spec or not model_spec.submodule_search_locations: return "" if not isinstance(model_spec.submodule_search_locations, list): return "" ms_dir = model_spec.submodule_search_locations[0] return ms_dir def check_mindspore_version(ms_dir): """check MindSpore version number""" try: from mindspore_serving.version import __version__ except ModuleNotFoundError: logger.warning(f"Get MindSpore Serving version failed") return try: with open(os.path.join(ms_dir, "version.py"), "r") as fp: version_str = fp.readline().replace("\n", "").replace("\r", "").replace(" ", "") \ .replace("'", "").replace("\"", "") prefix = "__version__=" if version_str[:len(prefix)] != prefix: logger.warning(f"Get MindSpore version failed") return ms_version = version_str[len(prefix):] except FileNotFoundError: logger.warning(f"Get MindSpore version failed") return serving_versions = __version__.split(".") ms_versions = ms_version.split(".") if serving_versions[:2] != ms_versions[:2]: logger.warning(f"MindSpore version {ms_version} and MindSpore Serving version {__version__} are expected " f"to be consistent. If not, there may be compatibility problems.") return def set_mindspore_cxx_env(): """Append MindSpore CXX lib path to LD_LIBRARY_PATH""" global _flag_set_mindspore_cxx_env if _flag_set_mindspore_cxx_env: return _flag_set_mindspore_cxx_env = True ld_lib_path = os.getenv('LD_LIBRARY_PATH', "") check_version_and_try_set_env_lib() # try set env LD_LIBRARY_PATH logger.info(f"Update env LD_LIBRARY_PATH from '{ld_lib_path}' to '{os.getenv('LD_LIBRARY_PATH')}'") ld_lib_path = os.getenv('LD_LIBRARY_PATH', "") ms_dir = get_mindspore_whl_path() if not ms_dir: logger.info(f"find mindspore failed, LD_LIBRARY_PATH will not add MindSpore lib path") return check_mindspore_version(ms_dir) ms_dir = os.path.join(ms_dir, "lib") if ld_lib_path: if ms_dir not in ld_lib_path.split(":"): os.environ['LD_LIBRARY_PATH'] = ld_lib_path + ":" + ms_dir else: os.environ['LD_LIBRARY_PATH'] = ms_dir logger.info(f"Update env LD_LIBRARY_PATH from '{ld_lib_path}' to '{os.getenv('LD_LIBRARY_PATH')}'") def init_mindspore_cxx_env(enable_lite): """Init env for load libmindspore.so""" set_mindspore_cxx_env() device_type = Worker_.get_device_type("none", enable_lite) if not device_type: logger.warning("Failed to get device type") return check_version_and_env_config(device_type) ================================================ FILE: mindspore_serving/server/worker/task.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
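# ---------------------------------------------------------------------------
# Illustrative aside, not part of this file: the compatibility rule applied
# by check_mindspore_version in init_mindspore.py above is "the major.minor
# prefixes must match". A minimal sketch; the version strings are examples.
# ---------------------------------------------------------------------------
def _versions_compatible(serving_version, ms_version):
    """True when the major.minor prefixes agree, e.g. '2.0.2' vs '2.0.0'."""
    return serving_version.split(".")[:2] == ms_version.split(".")[:2]

assert _versions_compatible("2.0.2", "2.0.0")
assert not _versions_compatible("2.0.2", "1.9.0")
# ---------------------------------------------------------------------------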
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Python run preprocess and postprocess in python""" import time import logging import numpy as np from mindspore_serving._mindspore_serving import Worker_ from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving.server.register.stage_function import stage_function_storage from mindspore_serving import log as logger class ServingSystemException(Exception): """Exception notify system error of worker, and need to exit py task""" def __init__(self, msg): super(ServingSystemException, self).__init__() self.msg = msg def __str__(self): return self.msg def has_worker_stopped(): """Whether worker has stopped""" return ExitSignalHandle_.has_stopped() class PyTaskHandler: """Handling preprocess and postprocess""" def run(self): """Run tasks of preprocess and postprocess, switch to other type of process when some instances are handled""" logger.info(f"start python task handling thread") while True: try: if has_worker_stopped(): logger.info("Worker has exited, exit python task handling thread") break task = Worker_.get_py_task() if task.has_stopped: logger.info("Worker has exited, exit python task handling thread") break self.run_inner(task) except Exception as e: # pylint: disable=broad-except logger.error(f"py task catch exception and exit: {e}") logging.exception(e) break logger.info("end python task handling thread") Worker_.stop_and_clear() @staticmethod def run_inner(task): """Iterator get result, and push it to c++""" task_name = task.task_name task_info = stage_function_storage.get(task_name) instance_list = task.instance_list # check input inputs_count = task_info["inputs_count"] for item in instance_list: if not isinstance(item, tuple) or len(item) != inputs_count: raise RuntimeError(f"The inputs number {len(item)} provided is not equal to the inputs number " f"{inputs_count} required by function {task_name}, stage index {task.stage_index}") instances_size = len(task.instance_list) index = 0 while index < instances_size: get_result_time_end = time.time() try: result = task_info["fun"](instance_list[index:]) # user-defined, may raise Exception if isinstance(result, (tuple, list)): # convert return result to yield result = iter(result) # pylint: disable=broad-except except Exception as e: logger.warning(f"{task_name} invoke catch exception: ") logging.exception(e) PyTaskHandler.push_failed(instances_size - index, str(e)) return # return will not terminate thread try: start_index = index for _ in range(index, instances_size): output = next(result) # user-defined, may raise Exception if not isinstance(output, (tuple, list)): output = (output,) # check output count if len(output) != task_info["outputs_count"]: error_msg = f"The outputs number {len(output)} of one instance returned by function " \ f"'{task_name}' is not equal to the outputs number {task_info['outputs_count']} " \ f" registered in method {task.method_name}" PyTaskHandler.push_system_failed(error_msg) raise ServingSystemException(error_msg) instance_result = [] for item in output: # convert MindSpore Tensor to numpy if callable(getattr(item, "asnumpy", None)): item = item.asnumpy() if isinstance(item, np.ndarray) and (not item.flags['FORC']): item = np.ascontiguousarray(item) instance_result.append(item) # raise ServingSystemException when user-defined output is invalid PyTaskHandler.push_result(instance_result) # push 
outputs of one instance index += 1 get_result_time = time.time() logger.info(f"method {task.method_name} stage {task.stage_index} function {task_name} get result " f"{start_index} ~ {instances_size - 1} cost time " f"{(get_result_time - get_result_time_end) * 1000} ms") except StopIteration: # raise by next error_msg = f"The number {index} of instances returned by function '{task_name}' is " \ f"not equal to the number {instances_size} of instances provided to this function." PyTaskHandler.push_system_failed(error_msg) raise RuntimeError(error_msg) except ServingSystemException as e: logger.error(f"{task_name} handling catch exception: {e}") PyTaskHandler.push_system_failed(e.msg) raise except Exception as e: # pylint: disable=broad-except # catch exception and try next logger.warning(f"{task_name} get result catch exception: {e}") logging.exception(e) PyTaskHandler.push_failed(1, str(e)) # push success results and a failed result index += 1 @staticmethod def push_failed(count, failed_msg): """Push failed result""" Worker_.push_pytask_failed(count, failed_msg) @staticmethod def push_system_failed(failed_msg): """Push failed result""" Worker_.push_pytask_system_failed(failed_msg) @staticmethod def push_result(instance_result): """Push success result""" try: Worker_.push_pytask_result(tuple(instance_result)) except Exception as e: raise ServingSystemException(f"Push py task result cause exception: {e}") def _start_py_task(): """Start python thread for python task""" if Worker_.enable_pytask_que(): PyTaskHandler().run() else: Worker_.wait_and_clear() ================================================ FILE: requirements_test.txt ================================================ numpy >= 1.17.0 protobuf >= 3.13.0 grpcio >= 1.36.0, <= 1.47.0 requests >= 2.22.0 psutil >= 5.9.1 ================================================ FILE: scripts/check_clang_format.sh ================================================ #!/bin/bash # Copyright 2019 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ set -e CLANG_FORMAT=$(which clang-format) || (echo "Please install 'clang-format' tool first"; exit 1) version=$("${CLANG_FORMAT}" --version | sed -n "s/.*\ \([0-9]*\)\.[0-9]*\.[0-9]*.*/\1/p") if [[ "${version}" -lt "8" ]]; then echo "clang-format's version must be at least 8.0.0" exit 1 fi CURRENT_PATH=$(pwd) SCRIPTS_PATH=$(dirname "$0") echo "CURRENT_PATH=$CURRENT_PATH" echo "SCRIPTS_PATH=$SCRIPTS_PATH" # print usage message function usage() { echo "Check whether the specified source files were well formatted" echo "Usage:" echo "bash $0 [-a] [-c] [-l] [-h]" echo "e.g. 
$0 -a" echo "" echo "Options:" echo " -a Check code format of all files, default case" echo " -c Check code format of the files changed compared to last commit" echo " -l Check code format of the files changed in last commit" echo " -h Print usage" } # check and set options function checkopts() { # init variable mode="all" # default check all files # Process the options while getopts 'aclh' opt do case "${opt}" in a) mode="all" ;; c) mode="changed" ;; l) mode="lastcommit" ;; h) usage exit 0 ;; *) echo "Unknown option ${opt}!" usage exit 1 esac done } # init variable # check options checkopts "$@" # switch to project root path, which contains clang-format config file '.clang-format' cd "${SCRIPTS_PATH}/.." || exit 1 CHECK_LIST_FILE='__checked_files_list__' if [ "X${mode}" == "Xall" ]; then find mindspore_serving/ccsrc -type f -name "*" | grep "\.h$\|\.cc$\|\.c$" > "${CHECK_LIST_FILE}" || true elif [ "X${mode}" == "Xchanged" ]; then # --diff-filter=ACMRTUXB will ignore deleted files in commit git diff --diff-filter=ACMRTUXB --name-only | grep "mindspore_serving/ccsrc" | grep "\.h$\|\.cc$\|\.c$" > "${CHECK_LIST_FILE}" || true else # "X${mode}" == "Xlastcommit" git diff --diff-filter=ACMRTUXB --name-only HEAD~ HEAD | grep "mindspore_serving/ccsrc" | grep "\.h$\|\.cc$\|\.c$" > "${CHECK_LIST_FILE}" || true fi CHECK_RESULT_FILE=__code_format_check_result__ echo "0" > "$CHECK_RESULT_FILE" # check format of files modified in the latest commit while read line; do BASE_NAME=$(basename "${line}") TEMP_FILE="__TEMP__${BASE_NAME}" cp "${line}" "${TEMP_FILE}" ${CLANG_FORMAT} -i "${TEMP_FILE}" diff "${TEMP_FILE}" "${line}" ret=$? rm "${TEMP_FILE}" if [[ "${ret}" -ne 0 ]]; then echo "File ${line} is not formatted, please format it." echo "1" > "${CHECK_RESULT_FILE}" break fi done < "${CHECK_LIST_FILE}" result=$(cat "${CHECK_RESULT_FILE}") rm "${CHECK_RESULT_FILE}" rm "${CHECK_LIST_FILE}" cd "${CURRENT_PATH}" || exit 1 if [[ "X${result}" == "X0" ]]; then echo "Check PASS: specified files are well formatted!" fi exit "${result}" ================================================ FILE: scripts/format_source_code.sh ================================================ #!/bin/bash # Copyright 2019 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ set -e CLANG_FORMAT=$(which clang-format) || (echo "Please install 'clang-format' tool first"; exit 1) version=$("${CLANG_FORMAT}" --version | sed -n "s/.*\ \([0-9]*\)\.[0-9]*\.[0-9]*.*/\1/p") if [[ "${version}" -lt "8" ]]; then echo "clang-format's version must be at least 8.0.0" exit 1 fi CURRENT_PATH=$(pwd) SCRIPTS_PATH=$(dirname "$0") echo "CURRENT_PATH=${CURRENT_PATH}" echo "SCRIPTS_PATH=${SCRIPTS_PATH}" # print usage message function usage() { echo "Format the specified source files to conform the code style." echo "Usage:" echo "bash $0 [-a] [-c] [-l] [-h]" echo "e.g. 
$0 -c" echo "" echo "Options:" echo " -a format of all files" echo " -c format of the files changed compared to last commit, default case" echo " -l format of the files changed in last commit" echo " -h Print usage" } # check and set options function checkopts() { # init variable mode="changed" # default format changed files # Process the options while getopts 'aclh' opt do case "${opt}" in a) mode="all" ;; c) mode="changed" ;; l) mode="lastcommit" ;; h) usage exit 0 ;; *) echo "Unknown option ${opt}!" usage exit 1 esac done } # init variable # check options checkopts "$@" # switch to project root path, which contains clang-format config file '.clang-format' cd "${SCRIPTS_PATH}/../.." || exit 1 FMT_FILE_LIST='__format_files_list__' if [[ "X${mode}" == "Xall" ]]; then find ./ -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true elif [[ "X${mode}" == "Xchanged" ]]; then git diff --name-only | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true else # "X${mode}" == "Xlastcommit" git diff --name-only HEAD~ HEAD | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true fi while read line; do if [ -f "${line}" ]; then ${CLANG_FORMAT} -i "${line}" fi done < "${FMT_FILE_LIST}" rm "${FMT_FILE_LIST}" cd "${CURRENT_PATH}" || exit 1 echo "Specified cpp source files have been format successfully." ================================================ FILE: setup.py ================================================ #!/usr/bin/env python3 # encoding: utf-8 # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ """setup package.""" import os import stat import platform from setuptools import setup, find_packages from setuptools.command.egg_info import egg_info from setuptools.command.build_py import build_py version = '2.0.2' backend_policy = os.getenv('BACKEND_POLICY') commit_id = os.getenv('COMMIT_ID').replace("\n", "") package_name = os.getenv('MS_PACKAGE_NAME').replace("\n", "") pwd = os.path.dirname(os.path.realpath(__file__)) pkg_dir = os.path.join(pwd, 'build/package') def _read_file(filename): with open(os.path.join(pwd, filename), encoding='UTF-8') as f: return f.read() readme = _read_file('README.md') release = _read_file('RELEASE.md') def _write_version(file): file.write("__version__ = '{}'\n".format(version)) def _write_config(file): file.write("__backend__ = '{}'\n".format(backend_policy)) def _write_commit_file(file): file.write("__commit_id__ = '{}'\n".format(commit_id)) def _write_package_name(file): file.write("__package_name__ = '{}'\n".format(package_name)) def build_dependencies(): """generate python file""" version_file = os.path.join(pkg_dir, 'mindspore_serving', 'version.py') with open(version_file, 'w') as f: _write_version(f) version_file = os.path.join(pwd, 'mindspore_serving', 'version.py') with open(version_file, 'w') as f: _write_version(f) config_file = os.path.join(pkg_dir, 'mindspore_serving', 'default_config.py') with open(config_file, 'w') as f: _write_config(f) config_file = os.path.join(pwd, 'mindspore_serving', 'default_config.py') with open(config_file, 'w') as f: _write_config(f) package_info = os.path.join(pkg_dir, 'mindspore_serving', 'default_config.py') with open(package_info, 'a') as f: _write_package_name(f) package_info = os.path.join(pwd, 'mindspore_serving', 'default_config.py') with open(package_info, 'a') as f: _write_package_name(f) commit_file = os.path.join(pkg_dir, 'mindspore_serving', '.commit_id') with open(commit_file, 'w') as f: _write_commit_file(f) commit_file = os.path.join(pwd, 'mindspore_serving', '.commit_id') with open(commit_file, 'w') as f: _write_commit_file(f) build_dependencies() required_package = [ 'numpy >= 1.17.0', 'protobuf >= 3.13.0', 'grpcio >= 1.36.0, <= 1.47.0', 'psutil >= 5.9.1' ] package_data = { '': [ '*.so*', '*.pyd', '*.dll', 'lib/*.so*', 'lib/*.a', '.commit_id', '_mindspore_serving', 'proto/*.py' ] } def update_permissions(path): """ Update permissions. Args: path (str): Target directory path. """ if platform.system() == "Windows": return for dirpath, dirnames, filenames in os.walk(path): for dirname in dirnames: dir_fullpath = os.path.join(dirpath, dirname) os.chmod(dir_fullpath, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP) for filename in filenames: file_fullpath = os.path.join(dirpath, filename) os.chmod(file_fullpath, stat.S_IREAD) def bin_files(): """ Gets the binary files to be installed. 
""" data_files = [] binary_files = [] cache_server_bin = os.path.join('mindspore_serving', 'bin', 'cache_server') if not os.path.exists(cache_server_bin): return data_files binary_files.append(cache_server_bin) cache_admin_bin = os.path.join('mindspore_serving', 'bin', 'cache_admin') if not os.path.exists(cache_admin_bin): return data_files binary_files.append(cache_admin_bin) data_files.append(('bin', binary_files)) return data_files class EggInfo(egg_info): """Egg info.""" def run(self): super().run() egg_info_dir = os.path.join(pkg_dir, 'mindspore_serving.egg-info') update_permissions(egg_info_dir) class BuildPy(build_py): """BuildPy.""" def run(self): super().run() mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore_serving') update_permissions(mindspore_dir) mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'akg') update_permissions(mindspore_dir) setup( name=package_name, version=version, author='The MindSpore Authors', author_email='contact@mindspore.cn', url='https://www.mindspore.cn', download_url='https://gitee.com/mindspore/serving/tags', project_urls={ 'Sources': 'https://gitee.com/mindspore/serving', 'Issue Tracker': 'https://gitee.com/mindspore/serving/issues', }, description='MindSpore is a new open source deep learning training/inference ' 'framework that could be used for mobile, edge and cloud scenarios.', # long_description="\n\n".join([readme, release]), long_description="\n\n".join([readme]), long_description_content_type="text/markdown", data_files=bin_files(), packages=find_packages(), package_data=package_data, include_package_data=True, cmdclass={ 'egg_info': EggInfo, 'build_py': BuildPy, }, python_requires='>=3.7', install_requires=required_package, classifiers=[ 'Development Status :: 4 - Beta', 'Environment :: Console', 'Intended Audience :: Science/Research', 'Intended Audience :: Developers', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3 :: Only', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: C++', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development', 'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', ], license='Apache 2.0', keywords='mindspore machine learning', ) ================================================ FILE: tests/CMakeLists.txt ================================================ #add flags message("================START BUILD TESTCASES=================") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare") add_subdirectory("ut") ================================================ FILE: tests/st/add/__init__.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ ================================================ FILE: tests/st/add/add.sh ================================================ #!/bin/bash export GLOG_v=1 cd "$(dirname $0)" || exit; CURRPATH=$(pwd) CURRUSER=$(whoami) PROJECT_PATH=${CURRPATH}/../../../ echo "CURRPATH:" ${CURRPATH} echo "CURRUSER:" ${CURRUSER} echo "PROJECT_PATH:" ${PROJECT_PATH} echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} echo "PYTHONPATH: " ${PYTHONPATH} rm -rf serving *.log *.mindir *.dat kernel_meta rm -rf unix_socket_files serving_logs rm -rf add serving_client.py serving_client_with_check.py export_model serving_server.py cp -r ../../../example/tensor_add/* . clean_pid() { ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 if [ $? -ne 0 ] then echo "clean pid failed" fi sleep 6 } prepare_model() { echo "### begin to generate model for serving test ###" cd export_model python3 add_model.py &> add_model.log echo "### end to generate model for serving test ###" result=`find . -name tensor_add.mindir | wc -l` if [ ${result} -ne 1 ] then cat add_model.log clean_pid echo "### generate model for serving test failed ###" && exit 1 fi cd - } start_service() { echo "### start serving service ###" unset http_proxy https_proxy python3 serving_server.py > serving_server.log 2>&1 & if [ $? -ne 0 ] then echo "server failed to start." fi result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` count=0 while [[ ${result} -eq 0 && ${count} -lt 150 ]] do sleep 1 count=$(($count+1)) result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` done if [ ${count} -eq 150 ] then clean_pid cat serving_server.log echo "worker log begin----------------------------------" cat serving_logs/*.log echo "worker log end----------------------------------" echo "start serving service failed!" && exit 1 fi echo "### start serving service end ###" } pytest_serving() { unset http_proxy https_proxy echo "### client start ###" python3 serving_client_with_check.py > client.log 2>&1 if [ $? -ne 0 ] then clean_pid cat client.log echo "client failed to start." && exit 1 fi echo "### client end ###" } test_add_model() { start_service pytest_serving cat client.log clean_pid } echo "-----serving start-----" prepare_model test_add_model ================================================ FILE: tests/st/add/test_serving.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# ============================================================================ import os import pytest import numpy as np @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_serving_add(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/add.sh") assert np.allclose(ret, 0) if __name__ == '__main__': test_serving_add() ================================================ FILE: tests/st/add_sub_pipeline/__init__.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ ================================================ FILE: tests/st/add_sub_pipeline/add_sub.sh ================================================ #!/bin/bash export GLOG_v=1 cd "$(dirname $0)" || exit; CURRPATH=$(pwd) CURRUSER=$(whoami) PROJECT_PATH=${CURRPATH}/../../../ echo "CURRPATH:" ${CURRPATH} echo "CURRUSER:" ${CURRUSER} echo "PROJECT_PATH:" ${PROJECT_PATH} echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} echo "PYTHONPATH: " ${PYTHONPATH} rm -rf serving *.log *.mindir *.dat kernel_meta rm -rf unix_socket_files serving_logs rm -rf add serving_client.py export_model serving_server.py add_sub cp -r ../../../example/add_sub_pipeline/* . clean_pid() { ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 if [ $? -ne 0 ] then echo "clean pid failed" fi sleep 6 } prepare_model() { echo "### begin to generate model for serving test ###" cd export_model python3 add_sub_model.py &> add_sub_model.log echo "### end to generate model for serving test ###" result=`find . -name tensor_add.mindir | wc -l` if [ ${result} -ne 1 ] then cat add_sub_model.log clean_pid echo "### generate model for serving test failed ###" && exit 1 fi cd - } start_service() { echo "### start serving service ###" unset http_proxy https_proxy python3 serving_server.py > serving_server.log 2>&1 & if [ $? -ne 0 ] then echo "server failed to start." fi result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` count=0 while [[ ${result} -eq 0 && ${count} -lt 150 ]] do sleep 1 count=$(($count+1)) result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` done if [ ${count} -eq 150 ] then clean_pid cat serving_server.log echo "worker log begin----------------------------------" cat serving_logs/*.log echo "worker log end----------------------------------" echo "start serving service failed!" && exit 1 fi echo "### start serving service end ###" } pytest_serving() { unset http_proxy https_proxy echo "### client start ###" python3 serving_client.py > client.log 2>&1 if [ $? -ne 0 ] then clean_pid cat client.log echo "client failed to start."
&& exit 1 fi echo "### client end ###" } test_add_sub_pipeline() { start_service pytest_serving cat client.log clean_pid } echo "-----serving start-----" prepare_model test_add_sub_pipeline ================================================ FILE: tests/st/add_sub_pipeline/test_serving.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ import os import pytest import numpy as np @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_serving_add_sub_pipeline(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/add_sub.sh") assert np.allclose(ret, 0) if __name__ == '__main__': test_serving_add_sub_pipeline() ================================================ FILE: tests/st/distributed_server_fault/__init__.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ ================================================ FILE: tests/st/distributed_server_fault/common.sh ================================================ #!/bin/bash export GLOG_v=1 cd "$(dirname $0)" || exit; CURRPATH=$(pwd) CURRUSER=$(whoami) PROJECT_PATH=${CURRPATH}/../../../ echo "CURRPATH:" ${CURRPATH} echo "CURRUSER:" ${CURRUSER} echo "PROJECT_PATH:" ${PROJECT_PATH} echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} echo "PYTHONPATH: " ${PYTHONPATH} get_serving_server_count() { num=`ps -ef | grep serving_server.py | grep -v grep | wc -l` return ${num} } get_serving_agent_count() { num=`ps -ef | grep serving_agent.py | grep -v grep | wc -l` return ${num} } clean_pid() { get_serving_server_count if [ $? -ne 0 ] then ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 fi count=0 get_serving_server_count while [[ $? -ne 0 && ${count} -lt 5 ]] do sleep 1 count=$(($count+1)) get_serving_server_count done get_serving_server_count if [ $? -ne 0 ] then ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 fi get_serving_agent_count if [ $? -ne 0 ] then ps aux | grep 'serving_agent.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 fi } prepare_model() { model_path=${CURRPATH}/../model if [ -d $model_path ] then echo "copy model path" cp -r ../model .
else echo "### begin to generate model for serving test ###" cd export_model || exit sh export_model.sh &> model.log echo "### end to generate model for serving test ###" result=`find ../ -name model | wc -l` if [ ${result} -ne 1 ] then cat model.log clean_pid echo "### generate model for serving test failed ###" && exit 1 fi cd - || exit cp -r model ../ fi } start_serving_server() { echo "### start serving server ###" unset http_proxy https_proxy python3 serving_server.py > serving_server.log 2>&1 & if [ $? -ne 0 ] then echo "serving server failed to start." fi result=`grep -E 'Master server start success, listening on' serving_server.log | wc -l` count=0 while [[ ${result} -eq 0 && ${count} -lt 150 ]] do sleep 1 get_serving_server_count if [ $? -eq 0 ] then clean_pid echo "serving server log begin-------------------" cat serving_server.log echo "serving server log end-------------------" echo "serving worker log begin-------------------" cat serving_logs/*.log echo "serving worker log end-------------------" echo "start serving server failed!" && exit 1 fi count=$(($count+1)) result=`grep -E 'Master server start success, listening on' serving_server.log | wc -l` done if [ ${count} -eq 150 ] then clean_pid echo "serving server log begin-------------------" cat serving_server.log echo "serving server log end-------------------" echo "serving worker log begin-------------------" cat serving_logs/*.log echo "serving worker log end-------------------" echo "start serving server failed!" && exit 1 fi echo "### start serving server end ###" } start_serving_agent() { echo "### start serving agent ###" unset http_proxy https_proxy python3 serving_agent.py > serving_agent.log 2>&1 & if [ $? -ne 0 ] then echo "serving agent failed to start." fi result=`grep -E 'Child 0: Receive success' serving_agent.log | wc -l` count=0 while [[ ${result} -ne 1 && ${count} -lt 150 ]] do sleep 1 get_serving_agent_count if [ $? -eq 0 ] then clean_pid cat serving_agent.log echo "start serving agent failed!" && exit 1 fi count=$(($count+1)) result=`grep -E 'Child 0: Receive success' serving_agent.log | wc -l` done if [ ${count} -eq 150 ] then clean_pid cat serving_agent.log echo "start serving agent failed!" && exit 1 fi echo "### start serving agent end ###" } wait_server_exit() { get_serving_server_count count=0 while [[ $? -ne 0 && ${count} -lt 15 ]] do sleep 1 count=$(($count+1)) get_serving_server_count done if [ ${count} -eq 15 ] then echo "serving server exit failed" ps -ef | grep serving_server.py | grep -v grep echo "------------------------------ serving server failed log begin: " cat serving_server.log echo "------------------------------ serving server failed log end" clean_pid && exit 1 fi } wait_agent_exit() { get_serving_agent_count count=0 while [[ $? -ne 0 && ${count} -lt 15 ]] do sleep 1 count=$(($count+1)) get_serving_agent_count done if [ ${count} -eq 15 ] then echo "serving agent exit failed" ps -ef | grep serving_agent.py | grep -v grep echo "------------------------------ serving agent failed log begin: " cat serving_agent.log echo "------------------------------ serving agent failed log end" clean_pid && exit 1 fi } init() { rm -rf serving *.log *.mindir *.dat matmul kernel_meta rm -rf unix_socket_files serving_logs rm -rf *.json export_model serving_server.py serving_agent.py serving_client.py cp -r ../../../example/matmul_distributed/* .
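# note: init() is shared by every fault case in this suite; it clears stale artifacts, copies the matmul_distributed example in, and exports (or reuses) the distributed model before any server or agent process is launched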
prepare_model } ================================================ FILE: tests/st/distributed_server_fault/kill_15_agent.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh kill_serving_agent() { get_serving_server_count if [ $? -ne 1 ] then echo "serving server start failed" echo $? clean_pid && exit 1 fi get_serving_agent_count if [ $? -ne 9 ] then echo "serving agent start failed" echo $? clean_pid && exit 1 fi ps aux | grep 'serving_agent.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 if [ $? -ne 0 ] then echo "kill agent failed" fi wait_agent_exit wait_server_exit } test_kill_serving_agent() { start_serving_server start_serving_agent kill_serving_agent clean_pid } echo "-----serving start-----" init test_kill_serving_agent echo "### end to serving test ###" ================================================ FILE: tests/st/distributed_server_fault/kill_15_server.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh kill_serving_server() { get_serving_server_count if [ $? -ne 1 ] then echo "master_with_worker start failed" echo $? clean_pid && exit 1 fi get_serving_agent_count if [ $? -ne 9 ] then echo "agent start failed" echo $? clean_pid && exit 1 fi ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 if [ $? -ne 0 ] then echo "kill master_with_worker failed" fi wait_agent_exit wait_server_exit } test_kill_serving_server() { start_serving_server start_serving_agent kill_serving_server clean_pid } echo "-----serving start-----" init test_kill_serving_server echo "### end to serving test ###" ================================================ FILE: tests/st/distributed_server_fault/kill_9_agent.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh kill_serving_agent() { get_serving_server_count if [ $? -ne 1 ] then echo "serving server start failed" echo $? clean_pid && exit 1 fi get_serving_agent_count if [ $? -ne 9 ] then echo "serving agent start failed" echo $? clean_pid && exit 1 fi num=`grep -E 'Recv Pong Time Out from' serving_logs/log_matmul*.log | wc -l` if [ $num -ne 0 ] then echo "serving agent has exited" echo $num clean_pid && exit 1 fi ps aux | grep 'serving_agent.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 if [ $? -ne 0 ] then echo "kill serving agent failed" fi sleep 25 get_serving_agent_count if [ $? -ne 0 ] then echo "agent exit failed" echo $? clean_pid && exit 1 fi get_serving_server_count if [ $? -ne 1 ] then echo "serving server start failed" echo $? clean_pid && exit 1 fi num=`grep -E 'Recv Pong Time Out from' serving_logs/log_matmul*.log | wc -l` if [ $num -ne 8 ] then echo "catch agent exit failed" echo $num clean_pid && exit 1 fi } test_kill_serving_agent() { start_serving_server start_serving_agent kill_serving_agent clean_pid } echo "-----serving start-----" init test_kill_serving_agent echo "### end to serving test ###" ================================================ FILE: tests/st/distributed_server_fault/kill_9_server.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh kill_serving_server() { get_serving_server_count if [ $? -ne 1 ] then echo "serving server start failed" echo $? 
clean_pid && exit 1 fi num=`ps -ef | grep start_distributed_worker.py | grep -v grep | wc -l` if [ ${num} -ne 1 ] then echo "serving worker start failed" echo ${num} clean_pid && exit 1 fi get_serving_agent_count if [ $? -ne 9 ] then echo "serving agent start failed" echo $? clean_pid && exit 1 fi num=`grep -E 'Recv Ping Time Out from' serving_server.log | wc -l` if [ $num -ne 0 ] then echo "serving agent has exited" echo $num clean_pid && exit 1 fi ps aux | grep 'start_distributed_worker.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 if [ $? -ne 0 ] then echo "kill serving worker failed" fi sleep 25 num=`grep -E 'Recv Ping Time Out from' serving_agent.log | wc -l` if [ $num -ne 8 ] then echo "catch serving server exit failed" echo $num clean_pid && exit 1 fi } test_kill_serving_server() { start_serving_server start_serving_agent kill_serving_server clean_pid } echo "-----serving start-----" init test_kill_serving_server echo "### end to serving test ###" ================================================ FILE: tests/st/distributed_server_fault/test_distributed_fault.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ import os import pytest import numpy as np @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_distribute_fault_kill_15_agent(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/kill_15_agent.sh") assert np.allclose(ret, 0) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_distribute_fault_kill_9_agent(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/kill_9_agent.sh") assert np.allclose(ret, 0) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_distribute_fault_kill_15_server(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/kill_15_server.sh") assert np.allclose(ret, 0) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_distribute_fault_kill_9_server(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/kill_9_server.sh") assert np.allclose(ret, 0) if __name__ == '__main__': test_distribute_fault_kill_9_server() test_distribute_fault_kill_15_server() test_distribute_fault_kill_9_agent() test_distribute_fault_kill_15_agent() ================================================ FILE: tests/st/matmul_distributed/__init__.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ ================================================ FILE: tests/st/matmul_distributed/matmul_distribute.sh ================================================ #!/bin/bash export GLOG_v=1 cd "$(dirname $0)" || exit CURRPATH=$(pwd) CURRUSER=$(whoami) PROJECT_PATH=${CURRPATH}/../../../ echo "CURRPATH:" ${CURRPATH} echo "CURRUSER:" ${CURRUSER} echo "PROJECT_PATH:" ${PROJECT_PATH} echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} echo "PYTHONPATH: " ${PYTHONPATH} clean_server_pid() { num=`ps -ef | grep serving_server.py | grep -v grep | wc -l` if [ ${num} -ne 0 ] then ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 if [ $? -ne 0 ] then echo "clean master pid failed" fi fi num=`ps -ef | grep start_distributed_worker.py | grep -v grep | wc -l` count=0 while [[ ${num} -ne 0 && ${count} -lt 10 ]] do sleep 1 count=$(($count+1)) num=`ps -ef | grep start_distributed_worker.py | grep -v grep | wc -l` done if [ ${count} -eq 10 ] then echo "worker exit failed" echo $num ps -ef | grep start_distributed_worker.py | grep -v grep echo "------------------------------ worker failed master log begin: " cat serving_server.log echo "------------------------------ worker failed master log end" echo "------------------------------ worker failed log begin: " cat serving_logs/*.log echo "------------------------------ worker failed log end" clean_pid && exit 1 fi num=`ps -ef | grep serving_agent.py | grep -v grep | wc -l` count=0 while [[ ${num} -ne 0 && ${count} -lt 10 ]] do sleep 1 count=$(($count+1)) num=`ps -ef | grep serving_agent.py | grep -v grep | wc -l` done if [ ${count} -eq 10 ] then echo "agent exit failed" echo $num ps -ef | grep serving_agent.py | grep -v grep echo "------------------------------ agent failed log begin: " cat serving_agent.log echo "------------------------------ agent failed log end" clean_pid && exit 1 fi } clean_pid() { ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep if [ $? -eq 0 ] then ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 echo "### master pid exists, clean master pid failed ###" fi ps aux | grep 'start_distributed_worker.py' | grep ${CURRUSER} | grep -v grep if [ $? -eq 0 ] then ps aux | grep 'start_distributed_worker.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 echo "### master pid is killed but worker pid exists ###" fi ps aux | grep 'serving_agent.py' | grep ${CURRUSER} | grep -v grep if [ $? -eq 0 ] then ps aux | grep 'serving_agent.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 echo "### worker pid is killed but agent pid exists ###" fi } prepare_model() { model_path=${CURRPATH}/../model if [ -d $model_path ] then echo "copy model path" cp -r ../model .
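# a model directory exported by an earlier distributed test run (${CURRPATH}/../model) is reused here to skip the slower multi-device export in the else branch below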
else echo "### begin to generate model for serving test ###" cd export_model || exit sh export_model.sh &> model.log echo "### end to generate model for serving test ###" result=`find ../ -name model | wc -l` if [ ${result} -ne 1 ] then echo "### begin model generation log ###" cat model.log echo "### end model generation log ###" clean_pid echo "### generate model for serving test failed ###" && exit 1 fi cd - || exit cp -r model ../ fi } start_serving_server() { echo "### start serving server ###" unset http_proxy https_proxy python3 serving_server.py > serving_server.log 2>&1 & if [ $? -ne 0 ] then echo "serving server failed to start." fi result=`grep -E 'Master server start success, listening on' serving_server.log | wc -l` count=0 while [[ ${result} -eq 0 && ${count} -lt 150 ]] do sleep 1 num=`ps -ef | grep serving_server.py | grep -v grep | wc -l` if [ ${num} -eq 0 ] then echo "serving server log begin-------------------" cat serving_server.log echo "serving server log end-------------------" echo "serving worker log begin-------------------" cat serving_logs/*.log echo "serving worker log end-------------------" clean_pid echo "start serving server failed!" && exit 1 fi count=$(($count+1)) result=`grep -E 'Master server start success, listening on' serving_server.log | wc -l` done if [ ${count} -eq 150 ] then echo "serving server log begin-------------------" cat serving_server.log echo "serving server log end-------------------" echo "serving worker log begin-------------------" cat serving_logs/*.log echo "serving worker log end-------------------" clean_pid echo "start serving server failed!" && exit 1 fi echo "### start serving server end ###" } start_serving_agent() { echo "### start serving agent ###" unset http_proxy https_proxy python3 serving_agent.py > serving_agent.log 2>&1 & if [ $? -ne 0 ] then echo "serving agent failed to start." fi result=`grep -E 'Child 0: Receive success' serving_agent.log | wc -l` count=0 while [[ ${result} -ne 1 && ${count} -lt 150 ]] do sleep 1 num=`ps -ef | grep serving_agent.py | grep -v grep | wc -l` if [ ${num} -eq 0 ] then clean_pid cat serving_agent.log echo "start serving agent failed!" && exit 1 fi count=$(($count+1)) result=`grep -E 'Child 0: Receive success' serving_agent.log | wc -l` done if [ ${count} -eq 150 ] then clean_pid cat serving_agent.log echo "start serving agent failed!" && exit 1 fi echo "### start serving agent end ###" } pytest_serving() { unset http_proxy https_proxy echo "### client start ###" python3 serving_client.py > serving_client.log 2>&1 if [ $? -ne 0 ] then cat serving_client.log clean_server_pid clean_pid echo "client failed to start." && exit 1 fi echo "### client end ###" } test_matmul_distribute() { start_serving_server start_serving_agent pytest_serving cat serving_client.log clean_server_pid clean_pid } echo "-----serving start-----" rm -rf serving *.log *.dat matmul model kernel_meta somas_meta rm -rf unix_socket_files serving_logs rm -rf serving_client.py export_model temp_rank_table serving_server.py serving_agent.py rank_table_8pcs.json cp -r ../../../example/matmul_distributed/* . prepare_model test_matmul_distribute ================================================ FILE: tests/st/matmul_distributed/test_matmul_distribute.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ import os import pytest import numpy as np @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_serving_matmul_distributed(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/matmul_distribute.sh") assert np.allclose(ret, 0) if __name__ == '__main__': test_serving_matmul_distributed() ================================================ FILE: tests/st/matmul_multi_subgraphs/__init__.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ ================================================ FILE: tests/st/matmul_multi_subgraphs/matmul_multi_subgraphs.sh ================================================ #!/bin/bash export GLOG_v=1 cd "$(dirname $0)" || exit; CURRPATH=$(pwd) CURRUSER=$(whoami) PROJECT_PATH=${CURRPATH}/../../../ echo "CURRPATH:" ${CURRPATH} echo "CURRUSER:" ${CURRUSER} echo "PROJECT_PATH:" ${PROJECT_PATH} echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} echo "PYTHONPATH: " ${PYTHONPATH} rm -rf serving *.log *.mindir *.dat kernel_meta rm -rf unix_socket_files serving_logs rm -rf add serving_client.py export_model serving_server.py cp -r ../../../example/matmul_multi_subgraphs/* . clean_pid() { ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 if [ $? -ne 0 ] then echo "clean pid failed" fi sleep 6 } prepare_model() { echo "### begin to generate model for serving test ###" cd export_model python3 export_matmul.py &> export_matmul.log echo "### end to generate model for serving test ###" result=`find . -name matmul_0.mindir | wc -l` if [ ${result} -ne 1 ] then cat export_matmul.log clean_pid echo "### generate model for serving test failed ###" && exit 1 fi cd - } start_service() { echo "### start serving service ###" unset http_proxy https_proxy python3 serving_server.py > serving_server.log 2>&1 & if [ $? -ne 0 ] then echo "server failed to start."
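# the $? check above only detects an immediate launch failure of the backgrounded server; readiness is confirmed by polling serving_server.log for the gRPC "start success" line below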
fi result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` count=0 while [[ ${result} -eq 0 && ${count} -lt 150 ]] do sleep 1 count=$(($count+1)) result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` done if [ ${count} -eq 150 ] then clean_pid cat serving_server.log echo "worker log begin----------------------------------" cat serving_logs/*.log echo "worker log end----------------------------------" echo "start serving service failed!" && exit 1 fi echo "### start serving service end ###" } pytest_serving() { unset http_proxy https_proxy echo "### client start ###" python3 serving_client.py > client.log 2>&1 if [ $? -ne 0 ] then clean_pid cat client.log echo "client failed to start." && exit 1 fi echo "### client end ###" } test_matmul_model() { start_service pytest_serving cat client.log clean_pid } echo "-----serving start-----" prepare_model test_matmul_model ================================================ FILE: tests/st/matmul_multi_subgraphs/test_matmul_multi_subgraphs.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ import os import pytest import numpy as np @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_serving_matmul_multi_subgraphs(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/matmul_multi_subgraphs.sh") assert np.allclose(ret, 0) if __name__ == '__main__': test_serving_matmul_multi_subgraphs() ================================================ FILE: tests/st/resnet/__init__.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# ============================================================================ ================================================ FILE: tests/st/resnet/resnet.sh ================================================ #!/bin/bash export GLOG_v=1 cd "$(dirname $0)" || exit CURRPATH=$(pwd) CURRUSER=$(whoami) PROJECT_PATH=${CURRPATH}/../../../ echo "CURRPATH:" ${CURRPATH} echo "CURRUSER:" ${CURRUSER} echo "PROJECT_PATH:" ${PROJECT_PATH} echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} echo "PYTHONPATH: " ${PYTHONPATH} clean_pid() { ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 if [ $? -ne 0 ] then echo "clean pid failed" fi sleep 6 } prepare_model() { echo "### begin to generate model for serving test ###" cd export_model python3 export_resnet.py False &> export_resnet.log echo "### end to generate model for serving test ###" result=`find . -name resnet50_1b_cifar10.mindir | wc -l` if [ ${result} -ne 1 ] then cat export_resnet.log clean_pid echo "### generate model for serving test failed ###" && exit 1 fi cd - } start_service() { echo "### start serving service ###" unset http_proxy https_proxy python3 serving_server.py > serving_server.log 2>&1 & if [ $? -ne 0 ] then echo "server failed to start." fi result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` count=0 while [[ ${result} -eq 0 && ${count} -lt 150 ]] do sleep 1 count=$(($count+1)) result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` done if [ ${count} -eq 150 ] then clean_pid cat serving_server.log echo "start serving service failed!" && exit 1 fi echo "### start serving service end ###" } pytest_serving() { unset http_proxy https_proxy echo "### client start ###" python3 serving_client.py > serving_client.log 2>&1 if [ $? -ne 0 ] then clean_pid cat serving_client.log echo "client failed to start." && exit 1 fi echo "### client end ###" } test_resnet_model() { start_service pytest_serving cat serving_client.log clean_pid } echo "-----serving start-----" rm -rf serving *.log *.mindir *.dat kernel_meta rm -rf unix_socket_files serving_logs rm -rf serving_client.py export_model serving_server.py resnet50 test_image cp -r ../../../example/resnet/* . prepare_model test_resnet_model echo "### end to serving test ###" ================================================ FILE: tests/st/resnet/test_resnet.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# ============================================================================ import os import pytest import numpy as np @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_resnet(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/resnet.sh") assert np.allclose(ret, 0) if __name__ == '__main__': test_resnet() ================================================ FILE: tests/st/serving_fault/__init__.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ ================================================ FILE: tests/st/serving_fault/common.sh ================================================ #!/bin/bash export GLOG_v=1 cd "$(dirname $0)" || exit CURRPATH=$(pwd) CURRUSER=$(whoami) PROJECT_PATH=${CURRPATH}/../../../ echo "CURRPATH:" ${CURRPATH} echo "CURRUSER:" ${CURRUSER} echo "PROJECT_PATH:" ${PROJECT_PATH} echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} echo "PYTHONPATH: " ${PYTHONPATH} clean_pid() { get_master_count if [ $? -ne 0 ] then ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 fi count=0 get_master_count while [[ $? -ne 0 && ${count} -lt 5 ]] do sleep 1 count=$(($count+1)) get_master_count done get_master_count if [ $? -ne 0 ] then ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 fi get_worker_count if [ $? -ne 0 ] then ps aux | grep 'start_worker.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 fi } prepare_model() { echo "### begin to generate model for serving test ###" cd export_model python3 add_model.py &> add_model.log echo "### end to generate model for serving test ###" result=`find . -name tensor_add.mindir | wc -l` if [ ${result} -ne 1 ] then cat add_model.log clean_pid echo "### generate model for serving test failed ###" && exit 1 fi cd - } start_serving_server() { echo "### start serving server ###" unset http_proxy https_proxy python3 serving_server.py > serving_server.log 2>&1 & if [ $? -ne 0 ] then echo "serving server failed to start." fi result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` count=0 while [[ ${result} -eq 0 && ${count} -lt 150 ]] do sleep 1 get_master_count if [ $? -eq 0 ] then echo "---------------------------------- serving server log begin" cat serving_server.log echo "---------------------------------- serving server log end" echo "---------------------------------- serving worker log begin" cat serving_logs/*.log echo "---------------------------------- serving worker log end" echo "start serving server failed!"
&& exit 1 fi count=$(($count+1)) result=`grep -E 'Serving gRPC server start success, listening on 127.0.0.1:5500' serving_server.log | wc -l` done if [ ${count} -eq 150 ] then clean_pid echo "---------------------------------- serving server log begin" cat serving_server.log echo "---------------------------------- serving server log end" echo "---------------------------------- serving worker log begin" cat serving_logs/*.log echo "---------------------------------- serving worker log end" echo "start serving server failed!" && exit 1 fi echo "### start serving server end ###" } get_master_count() { num=`ps -ef | grep serving_server.py | grep -v grep | wc -l` return ${num} } get_worker_count() { num=`ps -ef | grep start_worker.py | grep -v grep | wc -l` return ${num} } wait_master_exit() { get_master_count count=0 while [[ $? -ne 0 && ${count} -lt 15 ]] do sleep 1 count=$(($count+1)) get_master_count done if [ ${count} -eq 15 ] then echo "serving master exit failed" ps -ef | grep serving_server.py | grep -v grep echo "---------------------------------- serving server log begin" cat serving_server.log echo "---------------------------------- serving server log end" echo "---------------------------------- serving worker log begin" cat serving_logs/*.log echo "---------------------------------- serving worker log end" clean_pid && exit 1 fi } wait_worker_exit() { get_worker_count count=0 while [[ $? -ne 0 && ${count} -lt 15 ]] do sleep 1 count=$(($count+1)) get_worker_count done if [ ${count} -eq 15 ] then echo "serving worker exit failed" ps -ef | grep start_worker.py | grep -v grep echo "---------------------------------- serving server log begin" cat serving_server.log echo "---------------------------------- serving server log end" echo "---------------------------------- serving worker log begin" cat serving_logs/*.log echo "---------------------------------- serving worker log end" clean_pid && exit 1 fi } init() { rm -rf serving *.log *.mindir *.dat kernel_meta rm -rf unix_socket_files serving_logs rm -rf add export_model serving_server.py serving_client.py serving_client_with_check.py cp -r ../../../example/tensor_add/* . prepare_model clean_pid } ================================================ FILE: tests/st/serving_fault/kill_15_master.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh kill_master() { get_master_count if [ $? -ne 1 ] then echo "serving server start failed" echo $? clean_pid && exit 1 fi get_worker_count if [ $? -eq 0 ] then echo "worker start failed" echo $? clean_pid && exit 1 fi ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 if [ $? -ne 0 ] then echo "kill master failed" fi wait_master_exit wait_worker_exit } test_master_fault_model() { start_serving_server kill_master clean_pid } echo "-----serving start-----" init test_master_fault_model echo "### end to serving test ###" ================================================ FILE: tests/st/serving_fault/kill_15_worker.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh kill_worker() { get_master_count if [ $? -ne 1 ] then echo "serving server start failed" echo $? clean_pid && exit 1 fi get_worker_count if [ $? -eq 0 ] then echo "worker start failed" echo $? clean_pid && exit 1 fi ps aux | grep 'start_worker.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 if [ $?
-ne 0 ] then echo "kill worker failed" fi wait_master_exit wait_worker_exit } test_worker_fault_model() { start_serving_server kill_worker clean_pid } echo "-----serving start-----" init test_worker_fault_model echo "### end to serving test ###" ================================================ FILE: tests/st/serving_fault/kill_9_master.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh kill_master() { get_master_count if [ $? -ne 1 ] then echo "serving server start failed" echo $? clean_pid && exit 1 fi get_worker_count if [ $? -eq 0 ] then echo "worker start failed" echo $? clean_pid && exit 1 fi ps aux | grep 'serving_server.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 if [ $? -ne 0 ] then echo "kill master failed" fi wait_worker_exit } test_master_fault_model() { start_serving_server kill_master clean_pid } echo "-----serving start-----" init test_master_fault_model echo "### end to serving test ###" ================================================ FILE: tests/st/serving_fault/kill_9_worker.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh kill_worker() { get_master_count if [ $? -ne 1 ] then echo "serving server start failed" echo $? clean_pid && exit 1 fi get_worker_count if [ $? -eq 0 ] then echo "worker start failed" echo $? clean_pid && exit 1 fi ps aux | grep 'start_worker.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -9 if [ $? -ne 0 ] then echo "kill worker failed" fi wait_master_exit } test_worker_fault_model() { start_serving_server kill_worker clean_pid } echo "-----serving start-----" init test_worker_fault_model echo "### end to serving test ###" ================================================ FILE: tests/st/serving_fault/restart.sh ================================================ #!/bin/bash CURRPATH=$(cd "$(dirname $0)" || exit; pwd) source ${CURRPATH}/common.sh unset http_proxy https_proxy run_client() { echo "### client start ###" python3 serving_client_with_check.py > client.log 2>&1 if [ $? -ne 0 ] then clean_pid cat client.log echo "client failed to start." && exit 1 fi cat client.log echo "### client end ###" } listening_worker_restart() { start_count=$1 echo "### serving server worker restart begin ###" result=`grep -E 'Register success: worker address' serving_server.log | wc -l` count=0 while [[ ${result} -le $start_count && ${count} -lt 30 ]] do sleep 1 get_master_count if [ $? -eq 0 ] then echo "---------------------------------- serving server log begin" cat serving_server.log echo "---------------------------------- serving server log end" echo "---------------------------------- serving worker log begin" cat serving_logs/*.log echo "---------------------------------- serving worker log end" echo "serving server worker restart failed! start count $start_count" && exit 1 fi count=$(($count+1)) result=`grep -E 'Register success: worker address' serving_server.log | wc -l` done if [ ${count} -eq 30 ] then clean_pid echo "---------------------------------- serving server log begin" cat serving_server.log echo "---------------------------------- serving server log end" echo "---------------------------------- serving worker log begin" cat serving_logs/*.log echo "---------------------------------- serving worker log end" echo "serving server worker restart failed!
start count $start_count" && exit 1 fi echo "### serving server worker restart end ###" } test_restart() { start_serving_server # shellcheck disable=SC2207 worker_pids=($(ps aux | grep 'start_worker.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}')) if [ ${#worker_pids[*]} -ne 2 ]; then echo "worker process number is not 2, real count " ${#worker_pids[*]} ps -ef | grep start_worker.py clean_pid && exit 1 fi echo "before restart" ps -ef | grep 'start_worker.py' # test kill -9 and restart run_client echo "kill first worker " ${worker_pids[0]} kill -s 9 ${worker_pids[0]} echo "after first kill" ps -ef | grep 'start_worker.py' run_client listening_worker_restart 2 # current has 2 Register success log run_client echo "kill second worker " ${worker_pids[1]} kill -s 9 ${worker_pids[1]} echo "after second kill" ps -ef | grep 'start_worker.py' listening_worker_restart 3 # current has 3 Register success log # test kill -15 run_client # shellcheck disable=SC2207 worker_pids=($(ps aux | grep 'start_worker.py' | grep ${CURRUSER} | grep -v grep | awk '{print $2}')) if [ ${#worker_pids[*]} -ne 2 ]; then echo "restarted worker process number is not 2, real count " ${#worker_pids[*]} ps -ef | grep start_worker.py clean_pid && exit 1 fi echo "end restart" ps -ef | grep 'start_worker.py' kill -s 15 ${worker_pids[0]} kill -s 15 ${worker_pids[1]} wait_master_exit clean_pid } echo "-----serving start-----" init test_restart echo "-----serving end-----" ================================================ FILE: tests/st/serving_fault/test_serving_fault.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ import os import pytest import numpy as np @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_serving_fault_kill_15_master(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/kill_15_master.sh") assert np.allclose(ret, 0) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_serving_fault_kill_9_master(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/kill_9_master.sh") assert np.allclose(ret, 0) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_serving_fault_kill_15_worker(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/kill_15_worker.sh") assert np.allclose(ret, 0) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def test_serving_fault_kill_9_worker(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/kill_9_worker.sh") assert np.allclose(ret, 0) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.env_single def serving_fault_restart(): """test_serving""" sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/restart.sh") assert np.allclose(ret, 0) if __name__ == '__main__': test_serving_fault_kill_9_master() test_serving_fault_kill_15_master() test_serving_fault_kill_9_worker() test_serving_fault_kill_15_worker() ================================================ FILE: tests/ut/CMakeLists.txt ================================================ add_subdirectory(python) add_subdirectory(cpp) ================================================ FILE: tests/ut/coverage/cov_config ================================================ [run] omit = */__init__.py,*/*_pb2.py,*/*_pb2_grpc.py,*/tests/* ================================================ FILE: tests/ut/coverage/run_coverage.sh ================================================ #!/bin/bash # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ set -e BASEPATH=$( cd "$(dirname "$0")" pwd ) PROJECT_PATH=$( cd ${BASEPATH}/../../.. 
pwd ) BUILD_PKG=${PROJECT_PATH}/build/package export PYTHONPATH=${BUILD_PKG}:${PROJECT_PATH}/tests/ut/python:$PYTHONPATH export LD_LIBRARY_PATH=${BUILD_PKG}/tests/mindspore/lib:${LD_LIBRARY_PATH} echo "PYTHONPATH=$PYTHONPATH" echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" export GLOG_v=1 unset http_proxy unset https_proxy rm -rf cov_output htmlcov .coverage # run python ut pytest -v ${PROJECT_PATH}/tests/ut/python/tests/ --cov=${BUILD_PKG}/mindspore_serving --cov-config=${BASEPATH}/cov_config --cov-report=html --cov-branch # run cpp ut bash ../cpp/runtest.sh mkdir cov_output && cd cov_output lcov --capture --directory ${PROJECT_PATH}/build/mindspore_serving/ --output-file coverage.info; lcov --extract coverage.info '*/ccsrc/*' -o coverage.info; genhtml coverage.info --output-directory ./ --sort --legend ================================================ FILE: tests/ut/cpp/CMakeLists.txt ================================================ # This branch assumes that gRPC and all its dependencies are already installed # on this system, so they can be located by find_package(). # Find Protobuf installation # Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") # serving_common for c++ server and python interface file(GLOB_RECURSE UT_SERVING_CORE_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore_serving/ccsrc/common/*.cc" "../../../mindspore_serving/ccsrc/master/*.cc" "../../../mindspore_serving/ccsrc/worker/*.cc") file(GLOB_RECURSE UT_SERVING_RMV_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore_serving/ccsrc/worker/inference/inference.cc") list(REMOVE_ITEM UT_SERVING_CORE_SRC ${UT_SERVING_RMV_SRC}) file(GLOB_RECURSE UT_SERVING_STUB RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../stub/*.cc") set(UT_SERVING_COMMON ${UT_SERVING_CORE_SRC} ${UT_SERVING_STUB}) include_directories("${CMAKE_BINARY_DIR}/mindspore_serving" ${CMAKE_BINARY_DIR}) # for proto header file include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(../) include_directories(../stub) include_directories(../stub/include) include_directories(${CMAKE_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../mindspore_serving/ccsrc) link_directories(${CMAKE_BINARY_DIR}/securec/src) # copy libevent lib file(GLOB_RECURSE LIBEVENT_LIB_LIST ${libevent_LIBPATH}/libevent* ${libevent_LIBPATH}/libevent_pthreads*) file(COPY ${LIBEVENT_LIB_LIST} DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) # copy glog lib file(GLOB_RECURSE GLOG_LIB_LIST ${glog_LIBPATH}/libmindspore_serving_glog*) file(COPY ${GLOG_LIB_LIST} DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) # copy grpc lib file(GLOB_RECURSE GRPC_LIB_LIST ${grpc_LIBPATH}/lib*) file(COPY ${GRPC_LIB_LIST} DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) # for cpp/serving_ut set(CPP_UT_SERVING_CORE ${UT_SERVING_COMMON} ${UT_SERVING_ASCEND}) add_library(cpp_serving_common STATIC ${CPP_UT_SERVING_CORE}) target_link_libraries(cpp_serving_common PRIVATE PROTO_SRC_LIB) target_link_libraries(cpp_serving_common PRIVATE mindspore_serving::ssl mindspore_serving::crypto) target_link_libraries(cpp_serving_common PRIVATE mindspore_serving::grpc++) target_link_libraries(cpp_serving_common PRIVATE mindspore_serving::protobuf pthread rt dl) target_link_libraries(cpp_serving_common PRIVATE mindspore_serving::event mindspore_serving::event_pthreads) target_link_libraries(cpp_serving_common PRIVATE mindspore_serving::event_openssl)
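# note: the stub sources compiled into cpp_serving_common stand in for the real inference backend (inference.cc is filtered out above), so the serving_ut binary below can link the full serving core without a MindSpore runtime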
target_link_libraries(cpp_serving_common PRIVATE pthread mindspore_serving::glog) target_link_libraries(cpp_serving_common PRIVATE mindspore_serving::eigen) target_link_libraries(cpp_serving_common PRIVATE ${SECUREC_LIBRARY}) # for test link_directories(${CMAKE_BINARY_DIR}/googletest/googlemock/gtest) file(GLOB_RECURSE UT_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" "tests/*.cc") add_executable(serving_ut ${UT_LIST}) target_link_libraries(serving_ut PRIVATE mindspore_serving::gtest) target_link_libraries(serving_ut PRIVATE -Wl,--whole-archive cpp_serving_common -Wl,--no-whole-archive) # disable auto rpath set_target_properties(serving_ut PROPERTIES SKIP_BUILD_RPATH TRUE) # copy gtest lib file(GLOB_RECURSE GTEST_LIB_LIST ${gtest_LIBPATH}/libgtest*) file(COPY ${GTEST_LIB_LIST} DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) ================================================ FILE: tests/ut/cpp/common/common_test.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "common/common_test.h" #define private public #include "mindspore_serving/ccsrc/common/servable.h" #undef private #include "mindspore_serving/ccsrc/worker/servable_register.h" #ifdef __cplusplus #if __cplusplus extern "C" { #endif #endif namespace UT { void Common::SetUpTestCase() {} void Common::TearDownTestCase() {} void Common::SetUp() {} void Common::TearDown() { mindspore::serving::ServableRegister::Instance() = mindspore::serving::ServableRegister(); } } // namespace UT #ifdef __cplusplus #if __cplusplus } #endif #endif ================================================ FILE: tests/ut/cpp/common/common_test.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */
#ifndef TESTS_UT_COMMON_UT_COMMON_H_
#define TESTS_UT_COMMON_UT_COMMON_H_

#include
#include
#include
#include "gtest/gtest.h"

namespace UT {
class Common : public testing::Test {
 public:
  // TestCase only enter once
  static void SetUpTestCase();
  static void TearDownTestCase();
  // every TEST_F macro will enter one
  virtual void SetUp();
  virtual void TearDown();
};
} // namespace UT
#endif // TESTS_UT_COMMON_UT_COMMON_H_

================================================
FILE: tests/ut/cpp/common/test_main.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "gtest/gtest.h"

GTEST_API_ int main(int argc, char **argv) {
  testing::InitGoogleTest(&argc, argv);
  int ret = RUN_ALL_TESTS();
  return ret;
}

================================================
FILE: tests/ut/cpp/common/test_servable_common.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #ifndef MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H #define MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H #include #include #include #include #include #include #include #include #include "common/common_test.h" #include "master/server.h" #define private public #include "worker/worker.h" #undef private #include "worker/notfiy_master/base_notify.h" #include "worker/context.h" #include "worker/local_servable/local_model_loader.h" #include "master/grpc/grpc_process.h" #include "mindspore_serving/proto/ms_service.pb.h" #include "mindspore_serving/ccsrc/worker/servable_register.h" namespace mindspore { namespace serving { #define ExpectContainMsg(error_msg, expected_msg) \ { \ std::string error_msg_str = error_msg; \ EXPECT_TRUE(error_msg_str.find(expected_msg) != std::string::npos); \ if (error_msg_str.find(expected_msg) == std::string::npos) { \ std::cout << "error_msg: " << error_msg_str << ", expected_msg: " << expected_msg << std::endl; \ } \ } class FakeNotifyMaster : public BaseNotifyMaster { public: Status Register(const WorkerRegSpec &worker_spec) override { return SUCCESS; } Status Unregister() override { return SUCCESS; } }; class TestMasterWorker : public UT::Common { public: TestMasterWorker() = default; void Init(std::string servable_dir, std::string servable_name, int version_number, std::string model_file) { servable_dir_ = servable_dir; servable_name_ = servable_name; version_number_ = version_number; model_file_ = model_file; servable_name_path_ = servable_dir_ + "/" + servable_name_; version_number_path_ = servable_name_path_ + "/" + std::to_string(version_number_); model_name_path_ = version_number_path_ + "/" + model_file_; __mode_t access_mode = S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH; mkdir(servable_dir_.c_str(), access_mode); mkdir(servable_name_path_.c_str(), access_mode); mkdir(version_number_path_.c_str(), access_mode); std::ofstream fp(model_name_path_); fp << "model content"; fp.close(); model_name_path_list_.emplace(model_name_path_); version_number_path_list_.emplace(version_number_path_); servable_name_path_list_.emplace(servable_name_path_); servable_dir_list_.emplace(servable_dir_); } virtual void SetUp() {} virtual void TearDown() { for (auto &item : model_name_path_list_) { remove(item.c_str()); } for (auto &item : version_number_path_list_) { rmdir(item.c_str()); } for (auto &item : servable_name_path_list_) { rmdir(item.c_str()); } for (auto &item : servable_dir_list_) { rmdir(item.c_str()); } Worker::GetInstance().Clear(); Server::Instance().Clear(); UT::Common::TearDown(); } void StartAddServable() { auto status = StartServable(servable_dir_, servable_name_, 1); ASSERT_TRUE(status.IsSuccess()); } void RegisterAddServable(bool with_batch_dim = false) { DeclareServable(servable_name_, model_file_, "mindir", with_batch_dim); // register_method RegisterMethod(servable_name_, model_file_, "add_common", {"x1", "x2"}, {"y"}, 2, 1); } static Status StartServable(const std::string &servable_dir, const std::string &servable_name, int version_number) { char path[PATH_MAX]; std::string current_path = getcwd(path, PATH_MAX); auto notify_master = std::make_shared(); ServableContext::Instance()->SetDeviceId(0); ServableContext::Instance()->SetDeviceTypeStr("Ascend"); auto servable_dir_full = current_path + "/" + servable_dir; const auto &signature = ServableRegister::Instance().GetServableSignature(); Status status; std::map> models_loader; for (auto &model_meta : signature.model_metas) { auto &model_key = model_meta.common_meta.model_key; auto local_models_loader = 
std::make_shared(); status = local_models_loader->LoadModel(servable_dir_full, servable_name, version_number, model_meta, "", ""); if (status != SUCCESS) { local_models_loader->Clear(); return status; } status = local_models_loader->AfterLoadModel(); if (status != SUCCESS) { local_models_loader->Clear(); return status; } models_loader[model_key] = local_models_loader; } status = Worker::GetInstance().StartServableInner(servable_name, version_number, models_loader, true); return status; } static void DeclareServable(const std::string &servable_name, const std::string &model_file, const std::string &model_type, bool with_batch_dim = false) { ModelMeta servable_meta; servable_meta.common_meta.servable_name = servable_name; servable_meta.common_meta.model_key = model_file; servable_meta.common_meta.with_batch_dim = with_batch_dim; servable_meta.local_meta.model_files = {model_file}; servable_meta.local_meta.SetModelFormat(model_type); // declare_servable ServableRegister::Instance().DeclareModel(servable_meta); } static Status RegisterMethod(const std::string &servable_name, const std::string &method_file, const std::string &method_name, const std::vector &input_names, const std::vector &output_names, size_t servable_input_count, size_t servable_output_count) { auto model_key = method_file; auto status = ServableRegister::Instance().RegisterInputOutputInfo(model_key, servable_input_count, servable_output_count); if (status != SUCCESS) { return status; } MethodSignature method_signature; method_signature.servable_name = servable_name; method_signature.method_name = method_name; method_signature.inputs = input_names; method_signature.outputs = output_names; // method input 0 and input 1 as servable input std::vector> model_input = {{0, 0}, {0, 1}}; method_signature.AddStageModel(model_key, model_input, 0, ""); // servable output as method output std::vector> return_output = {{1, 0}}; method_signature.SetReturn(return_output); ServableRegister::Instance().RegisterMethod(method_signature); return SUCCESS; } std::string servable_dir_; std::string servable_name_; int version_number_ = 0; std::string model_file_; std::string model_name_path_; std::string version_number_path_; std::string servable_name_path_; std::set servable_dir_list_; std::set model_name_path_list_; std::set version_number_path_list_; std::set servable_name_path_list_; }; class TestMasterWorkerClient : public TestMasterWorker { public: TestMasterWorkerClient() = default; static void InitTensor(proto::Tensor *tensor, const std::vector &shape, proto::DataType data_type, const void *data, size_t data_size) { MSI_EXCEPTION_IF_NULL(tensor); tensor->set_dtype(data_type); auto proto_shape = tensor->mutable_shape(); for (auto item : shape) { proto_shape->add_dims(item); } tensor->set_data(data, data_size); } static std::vector InitOneInstanceRequest(proto::PredictRequest *request, const std::string &servable_name, const std::string &method_name, int version_number) { MSI_EXCEPTION_IF_NULL(request); auto request_servable_spec = request->mutable_servable_spec(); request_servable_spec->set_name(servable_name); request_servable_spec->set_method_name(method_name); request_servable_spec->set_version_number(version_number); std::vector x1_data = {1.1, 2.2, 3.3, 4.4}; std::vector x2_data = {1.2, 2.3, 3.4, 4.5}; std::vector y_data; for (size_t i = 0; i < x1_data.size(); i++) { y_data.push_back(x1_data[i] + x2_data[i]); } auto instance = request->add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 
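// (each instance maps input names to proto tensors; x1 and x2 are flattened
//  2x2 float32 payloads, and the returned y_data holds the expected
//  elementwise sums used later to verify the reply)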
InitTensor(&input_map["x1"], {2, 2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x2 InitTensor(&input_map["x2"], {2, 2}, proto::MS_FLOAT32, x2_data.data(), x2_data.size() * sizeof(float)); return y_data; } template static std::vector> InitMultiInstancesRequest(proto::PredictRequest *request, const std::string &servable_name, const std::string &method_name, int version_number, size_t instances_count) { MSI_EXCEPTION_IF_NULL(request); auto request_servable_spec = request->mutable_servable_spec(); request_servable_spec->set_name(servable_name); request_servable_spec->set_method_name(method_name); request_servable_spec->set_version_number(version_number); auto data_type = proto::MS_FLOAT32; if (std::string(typeid(IN_DT).name()) == std::string(typeid(int32_t).name())) { data_type = proto::MS_INT32; } std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data_org = {1.1, 2.2, 3.3, 4.4}; std::vector x2_data_org = {6.6, 7.7, 8.8, 9.9}; std::vector x1_data; std::vector x2_data; std::vector y_data; for (size_t i = 0; i < x1_data_org.size(); i++) { x1_data.push_back(static_cast(x1_data_org[i] * (k + 1))); x2_data.push_back(static_cast(x2_data_org[i] * (k + 1))); y_data.push_back(static_cast(x1_data[i] + x2_data[i])); } y_data_list.push_back(y_data); auto instance = request->add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2, 2}, data_type, x1_data.data(), x1_data.size() * sizeof(IN_DT)); // input x2 InitTensor(&input_map["x2"], {2, 2}, data_type, x2_data.data(), x2_data.size() * sizeof(IN_DT)); } return y_data_list; } template static std::vector> InitMultiInstancesShape2Request(proto::PredictRequest *request, const std::string &servable_name, const std::string &method_name, int version_number, size_t instances_count) { MSI_EXCEPTION_IF_NULL(request); auto request_servable_spec = request->mutable_servable_spec(); request_servable_spec->set_name(servable_name); request_servable_spec->set_method_name(method_name); request_servable_spec->set_version_number(version_number); auto data_type = proto::MS_FLOAT32; if (std::string(typeid(IN_DT).name()) == std::string(typeid(int32_t).name())) { data_type = proto::MS_INT32; } std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data_org = {1.1, 2.2}; std::vector x2_data_org = {8.8, 9.9}; std::vector x1_data; std::vector x2_data; std::vector y_data; for (size_t i = 0; i < x1_data_org.size(); i++) { x1_data.push_back(static_cast(x1_data_org[i] * (k + 1))); x2_data.push_back(static_cast(x2_data_org[i] * (k + 1))); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request->add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2}, data_type, x1_data.data(), x1_data.size() * sizeof(IN_DT)); // input x2 InitTensor(&input_map["x2"], {2}, data_type, x2_data.data(), x2_data.size() * sizeof(IN_DT)); } return y_data_list; } template static void CheckMultiInstanceResult(const proto::PredictReply &reply, const std::vector> &y_data_list, size_t instances_count) { // checkout output ASSERT_EQ(reply.instances_size(), instances_count); ASSERT_EQ(reply.error_msg_size(), 0); auto data_type = proto::MS_FLOAT32; if (std::string(typeid(OUT_DT).name()) == std::string(typeid(int32_t).name())) { data_type = proto::MS_INT32; } std::vector shape; if (y_data_list[0].size() == 4) { shape = {2, 2}; } else { shape = {2}; } for (size_t k = 0; k 
< instances_count; k++) { auto &output_instance = reply.instances(k); ASSERT_EQ(output_instance.items_size(), 1); auto &output_items = output_instance.items(); ASSERT_EQ(output_items.begin()->first, "y"); auto &output_tensor = output_items.begin()->second; CheckTensor(output_tensor, shape, data_type, y_data_list[k].data(), y_data_list[k].size() * sizeof(OUT_DT)); } } template static void CheckInstanceResult(const proto::PredictReply &reply, const std::vector &y_data) { // checkout output ASSERT_EQ(reply.instances_size(), 1); ASSERT_EQ(reply.error_msg_size(), 0); auto data_type = proto::MS_FLOAT32; if (std::string(typeid(OUT_DT).name()) == std::string(typeid(int32_t).name())) { data_type = proto::MS_INT32; } std::vector shape; if (y_data.size() == 4) { shape = {2, 2}; } else { shape = {2}; } auto &output_instance = reply.instances(0); ASSERT_EQ(output_instance.items_size(), 1); auto &output_items = output_instance.items(); ASSERT_EQ(output_items.begin()->first, "y"); auto &output_tensor = output_items.begin()->second; CheckTensor(output_tensor, shape, data_type, y_data.data(), y_data.size() * sizeof(OUT_DT)); } static void CheckTensor(const proto::Tensor &output_tensor, const std::vector &shape, proto::DataType data_type, const void *data, size_t data_size) { EXPECT_EQ(output_tensor.dtype(), data_type); // check shape [2,2] auto &output_tensor_shape = output_tensor.shape(); ASSERT_EQ(output_tensor_shape.dims_size(), shape.size()); std::vector proto_shape; for (size_t i = 0; i < output_tensor_shape.dims_size(); i++) { proto_shape.push_back(output_tensor_shape.dims(i)); } EXPECT_EQ(proto_shape, shape); // check data ASSERT_EQ(output_tensor.data().size(), data_size); switch (data_type) { case proto::MS_FLOAT32: { auto data_len = data_size / sizeof(float); auto real_data = reinterpret_cast(output_tensor.data().data()); auto expect_data = reinterpret_cast(data); for (size_t i = 0; i < data_len; i++) { EXPECT_EQ(real_data[i], expect_data[i]); if (real_data[i] != expect_data[i]) { break; } } break; } case proto::MS_INT32: { auto data_len = data_size / sizeof(int32_t); auto real_data = reinterpret_cast(output_tensor.data().data()); auto expect_data = reinterpret_cast(data); for (size_t i = 0; i < data_len; i++) { EXPECT_EQ(real_data[i], expect_data[i]); if (real_data[i] != expect_data[i]) { break; } } break; } default: FAIL(); } } static grpc::Status Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) { MSWorkerImpl impl; auto promise = std::make_shared>(); auto future = promise->get_future(); PredictOnFinish callback = [promise]() { promise->set_value(); }; impl.PredictAsync(&request, reply, callback); future.get(); return grpc::Status::OK; } }; } // namespace serving } // namespace mindspore #endif // MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H ================================================ FILE: tests/ut/cpp/runtest.sh ================================================ #!/bin/bash # Copyright 2019 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
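# Usage: bash runtest.sh [gtest_filter]
#   e.g. bash runtest.sh "TestMasterWorkerClient.*" runs a single suite;
#   the optional first argument is forwarded to --gtest_filter below.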
# ============================================================================

set -e
BASEPATH=$(
  cd "$(dirname "$0")"
  pwd
)
PROJECT_PATH=${BASEPATH}/../../..
if [ "$BUILD_PATH" ]; then
  echo "BUILD_PATH = $BUILD_PATH"
else
  BUILD_PATH=${PROJECT_PATH}/build
  echo "BUILD_PATH = $BUILD_PATH"
fi

cd ${BUILD_PATH}/mindspore_serving/tests/ut/cpp
export LD_LIBRARY_PATH=${BUILD_PATH}/mindspore_serving/tests/ut/cpp:${LD_LIBRARY_PATH}
echo "LD_LIBRARY_PATH = $LD_LIBRARY_PATH"

if [ $# -gt 0 ]; then
  ./serving_ut --gtest_filter=$1
else
  ./serving_ut
fi
RET=$?
cd -
exit ${RET}

================================================
FILE: tests/ut/cpp/tests/test_agent_config_acquire.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common_test.h"
#include "common/tensor_base.h"
#define private public
#include "worker/distributed_worker/distributed_process/distributed_process.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"
#undef private

using std::string;
using std::vector;

namespace mindspore {
namespace serving {
class TestAgentConfigAcquire : public UT::Common {
 public:
  TestAgentConfigAcquire() = default;
  virtual void SetUp() {}
  virtual void TearDown() { UT::Common::TearDown(); }
};

TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_success) {
  // template argument restored from context: the loader under test exposes
  // config_ and config_loaded_ and is handed to MSDistributedImpl
  std::shared_ptr<DistributedModelLoader> servable = std::make_shared<DistributedModelLoader>();
  std::string rank_table_content = "rank table content";
  CommonModelMeta commonServableMeta;
  commonServableMeta.servable_name = "servable_name";
  commonServableMeta.model_key = "model_key";
  commonServableMeta.outputs_count[0] = 1;
  commonServableMeta.inputs_count[0] = 1;
  commonServableMeta.with_batch_dim = false;
  commonServableMeta.without_batch_dim_inputs.push_back(8);
  DistributedModelMeta distributedServableMeta;
  distributedServableMeta.stage_size = 8;
  distributedServableMeta.rank_size = 8;
  OneRankConfig oneRankConfig;
  oneRankConfig.ip = "1.1.1.1";
  oneRankConfig.device_id = 0;
  servable->config_.rank_table_content = rank_table_content;
  servable->config_.common_meta = commonServableMeta;
  servable->config_.distributed_meta = distributedServableMeta;
  servable->config_.rank_list.push_back(oneRankConfig);
  servable->config_loaded_ = true;
  const std::string server_address = "any_addr";
  MSDistributedImpl mSDistributedImpl(servable, server_address);
  grpc::ServerContext context;
  const proto::AgentConfigAcquireRequest request;
  proto::AgentConfigAcquireReply reply;
  grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply);
  ASSERT_EQ(status.error_code(), 0);
  DistributedServableConfig config;
  GrpcNotifyDistributeWorker::ParseAgentConfigAcquireReply(reply, &config);
  ASSERT_EQ(config.rank_table_content, rank_table_content);
  ASSERT_EQ(config.common_meta.servable_name, "servable_name");
  ASSERT_EQ(config.common_meta.model_key, "model_key");
  ASSERT_EQ(config.common_meta.inputs_count.at(0), 1);
  ASSERT_EQ(config.common_meta.outputs_count.at(0), 1);
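  // round-trip check continues: the reply filled by AgentConfigAcquire was
  // decoded back into a DistributedServableConfig above, and the remaining
  // fields cover batch-dim handling and the distributed topology
  // (rank/stage sizes and per-rank ip/device_id):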
ASSERT_EQ(config.common_meta.with_batch_dim, false); ASSERT_EQ(config.common_meta.without_batch_dim_inputs.size(), 1); ASSERT_EQ(config.common_meta.without_batch_dim_inputs.at(0), 8); ASSERT_EQ(config.distributed_meta.rank_size, 8); ASSERT_EQ(config.distributed_meta.stage_size, 8); ASSERT_EQ(config.rank_list.size(), 1); OneRankConfig tempRankConfig = config.rank_list.at(0); ASSERT_EQ(tempRankConfig.device_id, 0); ASSERT_EQ(tempRankConfig.ip, "1.1.1.1"); } TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_not_load_config_failed) { std::shared_ptr servable = std::make_shared(); servable->config_loaded_ = false; const std::string server_address = "any_addr"; MSDistributedImpl mSDistributedImpl(servable, server_address); grpc::ServerContext context; const proto::AgentConfigAcquireRequest request; proto::AgentConfigAcquireReply reply; const grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply); ASSERT_EQ(status.error_code(), 1); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/cpp/tests/test_context.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "../common/common_test.h" #define private public #include "worker/inference/inference.h" #include "worker/inference/mindspore_model_wrap.h" #undef private using std::string; using std::vector; namespace mindspore { namespace serving { class TestModelContext : public UT::Common { public: TestModelContext() = default; void Init(std::string file_name) { char *dir; dir = get_current_dir_name(); std::string file_path(dir); file_path += file_name; std::ofstream fp(file_path); fp << "model content"; fp.close(); model_file = file_path; free(dir); } virtual void SetUp() { setenv("SERVING_ENABLE_CPU_DEVICE", "1", 1); setenv("SERVING_ENABLE_GPU_DEVICE", "1", 1); } virtual void TearDown() { remove(model_file.c_str()); setenv("SERVING_ENABLE_CPU_DEVICE", "0", 1); setenv("SERVING_ENABLE_GPU_DEVICE", "0", 1); } std::string model_file; }; /// Feature: model context /// Description: ascend910 device with mindspore /// Expectation: the context has ascend910 and load success TEST_F(TestModelContext, test_ms_set_ascend910) { setenv("SERVING_ENABLE_CPU_DEVICE", "0", 1); setenv("SERVING_ENABLE_GPU_DEVICE", "0", 1); Init("tensor_add.mindir@ms_ascend"); ModelContext model_context; auto mindspore_wrap = InferenceLoader::Instance().CreateMindSporeInfer(); auto status = mindspore_wrap->LoadModelFromFile(serving::DeviceType::kDeviceTypeAscend, 0, {model_file}, serving::kMindIR, false, {}, model_context, {}, {}, {}, false); ASSERT_TRUE(status.IsSuccess()); } /// Feature: model context /// Description: gpu device with lite /// Expectation: the context has gpu and load success TEST_F(TestModelContext, test_lite_set_gpu) { Init("tensor_add.mindir@lite_gpu_cpu"); ModelContext model_context; auto mindspore_wrap = InferenceLoader::Instance().CreateMindSporeInfer(); auto status = 
mindspore_wrap->LoadModelFromFile(serving::DeviceType::kDeviceTypeGpu, 0, {model_file}, serving::kMindIR, false, {}, model_context, {}, {}, {}, true); ASSERT_TRUE(status.IsSuccess()); } /// Feature: Model context /// Description: gpu cpu device with lite /// Expectation: the context has gpu and cpu and load success TEST_F(TestModelContext, test_lite_set_gpu_cpu) { Init("tensor_add.mindir@lite_gpu_cpu"); ModelContext model_context; DeviceInfo cpu_device_info{{"device_type", "cpu"}}; model_context.device_list.push_back(cpu_device_info); auto mindspore_wrap = InferenceLoader::Instance().CreateMindSporeInfer(); auto status = mindspore_wrap->LoadModelFromFile(serving::DeviceType::kDeviceTypeGpu, 0, {model_file}, serving::kMindIR, false, {}, model_context, {}, {}, {}, true); ASSERT_TRUE(status.IsSuccess()); } /// Feature: Model context /// Description: gpu cpu device with mindspore /// Expectation: the context only has gpu and load success TEST_F(TestModelContext, test_ms_set_gpu) { Init("tensor_add.mindir@ms_gpu"); ModelContext model_context; DeviceInfo cpu_device_info{{"device_type", "cpu"}}; model_context.device_list.push_back(cpu_device_info); auto mindspore_wrap = InferenceLoader::Instance().CreateMindSporeInfer(); auto status = mindspore_wrap->LoadModelFromFile(serving::DeviceType::kDeviceTypeGpu, 0, {model_file}, serving::kMindIR, false, {}, model_context, {}, {}, {}, false); ASSERT_TRUE(status.IsSuccess()); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/cpp/tests/test_distributed_inference.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include #include #include #include "gtest/gtest.h" #include "common/status.h" #include "proto/ms_agent.pb.h" #include "tests/ut/cpp/common/common_test.h" #include "common/grpc_client.h" #include "worker/distributed_worker/notify_agent/base_notify_agent.h" #define private public #include "common/exit_handle.h" #include "worker/distributed_worker/distributed_model_loader.h" #undef private namespace mindspore { namespace serving { struct AgentInferResult { int64_t prediction_time = 0; // milliseconds Status status = SUCCESS; int64_t error_code = 0; std::string error_msg = ""; }; class FakeNotifyAgent : public BaseNotifyAgent { public: explicit FakeNotifyAgent(int64_t prediction_time = 0, Status status = SUCCESS, int64_t error_code = 0, std::string error_msg = "") : prediction_time_(prediction_time), status_(status), error_code_(error_code), error_msg_(error_msg) {} ~FakeNotifyAgent() = default; Status Exit() override { return SUCCESS; } Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply, AsyncPredictCallback callback) override { auto error_msg = reply->mutable_error_msg(); error_msg->set_error_code(error_code_); if (!error_msg_.empty()) { error_msg->set_error_msg(error_msg_); } auto prediction_time = prediction_time_; auto status = status_; auto predict = [prediction_time, status, callback]() { std::chrono::milliseconds dura(prediction_time); std::this_thread::sleep_for(dura); callback(status); }; std::thread t1(predict); t1.detach(); return SUCCESS; } private: int64_t prediction_time_; // milliseconds Status status_; int64_t error_code_; std::string error_msg_; }; class TestDistributedInference : public UT::Common { public: TestDistributedInference() = default; ~TestDistributedInference() = default; void InitDistributedServable(std::shared_ptr servable, size_t rank_size, size_t stage_size, bool is_running, bool is_loaded) { ExitSignalHandle::Instance().is_running_ = is_running; servable->model_loaded_ = is_loaded; servable->config_.distributed_meta.rank_size = rank_size; servable->config_.distributed_meta.stage_size = stage_size; } void InitAgentSpecMap(std::shared_ptr servable, const std::vector &result_list) { for (size_t rank_id = 0; rank_id < result_list.size(); ++rank_id) { const auto &result = result_list[rank_id]; DistributedAgentContext agent_context; agent_context.notify_agent_ = std::make_shared(result.prediction_time, result.status, result.error_code, result.error_msg); servable->agent_spec_map_.insert({rank_id, agent_context}); } } }; TEST_F(TestDistributedInference, test_agent_8_stage_1) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 8, 1, true, true); std::vector result_list(8); InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), SUCCESS); } TEST_F(TestDistributedInference, test_agent_4) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 4, 1, true, true); std::vector result_list(4); InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), SUCCESS); } TEST_F(TestDistributedInference, test_agent_32_stage_1) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 1, true, true); std::vector result_list(32); 
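    // one default-constructed AgentInferResult per rank: every FakeNotifyAgent
    // replies SUCCESS asynchronously from a detached thread, so Predict must
    // gather all 32 callbacks before returning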
InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), SUCCESS); } TEST_F(TestDistributedInference, test_agent_32_stage_2) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 2, true, true); std::vector result_list(32); InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), SUCCESS); } TEST_F(TestDistributedInference, test_agent_32_stage_4) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 4, true, true); std::vector result_list(32); InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), SUCCESS); } TEST_F(TestDistributedInference, test_agent_64_stage_8) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 64, 8, true, true); std::vector result_list(64); InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), SUCCESS); } TEST_F(TestDistributedInference, test_output_nullptr) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 4, true, true); std::vector result_list(32); InitAgentSpecMap(servable, result_list); Status status; std::vector input, output; ASSERT_ANY_THROW({ status = servable->Predict(input, nullptr); }); ASSERT_EQ(status.StatusCode(), FAILED); } TEST_F(TestDistributedInference, test_agent_infer_more_than_10s) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 4, true, true); std::vector result_list(32); result_list[20].prediction_time = 11000; InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), FAILED); } TEST_F(TestDistributedInference, test_agent_exit) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 4, false, true); std::vector result_list(32); InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), FAILED); } TEST_F(TestDistributedInference, test_rank_size_not_equal_agent_num) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 4, true, true); std::vector result_list(12); InitAgentSpecMap(servable, result_list); Status status; std::vector input, output; ASSERT_ANY_THROW({ status = servable->Predict(input, &output); }); ASSERT_EQ(status.StatusCode(), FAILED); } TEST_F(TestDistributedInference, test_agent_reply_with_error_msg) { auto servable = std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 4, true, true); std::vector result_list(32); result_list[10].error_msg = "failed"; result_list[10].error_code = 1; InitAgentSpecMap(servable, result_list); std::vector input, output; auto status = servable->Predict(input, &output); ASSERT_EQ(status.StatusCode(), FAILED); } TEST_F(TestDistributedInference, test_model_not_loaded) { auto servable = 
std::make_shared(); servable->model_key_ = "test_distributed_model_key"; InitDistributedServable(servable, 32, 4, true, false); std::vector result_list(32); InitAgentSpecMap(servable, result_list); Status status; std::vector input, output; ASSERT_ANY_THROW({ status = servable->Predict(input, &output); }); ASSERT_EQ(status.StatusCode(), FAILED); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/cpp/tests/test_init_config_on_start_up.cc ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "common/common_test.h" #include "common/tensor_base.h" #define private public #include "worker/distributed_worker/distributed_model_loader.h" #undef private using std::string; using std::vector; namespace mindspore { namespace serving { class TestParseRankTableFile : public UT::Common { public: TestParseRankTableFile() = default; virtual void SetUp() {} virtual void TearDown() { for (auto &item : config_file_list_) { remove(item.c_str()); } UT::Common::TearDown(); } std::set config_file_list_; }; TEST_F(TestParseRankTableFile, test_init_config_on_startup_empty_file_failed) { std::string empty_rank_table_file = "empty_rank_table_file"; std::ofstream fp(empty_rank_table_file); fp << "empty rank table file"; fp.close(); config_file_list_.emplace(empty_rank_table_file); auto servable = std::make_shared(); auto status = servable->InitConfigOnStartup(empty_rank_table_file); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_init_config_on_startup_success) { nlohmann::json rank_table_server_list = R"( { "server_list": [ { "server_id": "10.155.111.140", "device": [ {"device_id": "0","device_ip": "192.1.27.6","rank_id": "0"}, {"device_id": "1","device_ip": "192.2.27.6","rank_id": "1"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; std::string rank_table_file = "rank_table_file"; std::ofstream fp(rank_table_file); fp << rank_table_server_list; fp.close(); config_file_list_.emplace(rank_table_file); auto servable = std::make_shared(); auto status = servable->InitConfigOnStartup(rank_table_file); ASSERT_EQ(status.StatusCode(), SUCCESS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_server_list_success) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "10.155.111.140", "device": [ {"device_id": "0","device_ip": "192.1.27.6","rank_id": "0"}, {"device_id": "1","device_ip": "192.2.27.6","rank_id": "1"}, {"device_id": "2","device_ip": "192.3.27.6","rank_id": "2"}, {"device_id": "3","device_ip": "192.4.27.6","rank_id": "3"}, {"device_id": "4","device_ip": "192.1.27.7","rank_id": "4"}, {"device_id": "5","device_ip": "192.2.27.7","rank_id": "5"}, {"device_id": "6","device_ip": "192.3.27.7","rank_id": "6"}, {"device_id": "7","device_ip": "192.4.27.7","rank_id": "7"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto 
servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), SUCCESS); ASSERT_EQ(servable->config_.rank_list.size(), 8); uint32_t expect_device_id = 0; for (auto &one_rank_config : servable->config_.rank_list) { std::string server_ip = one_rank_config.ip; uint32_t device_id = one_rank_config.device_id; ASSERT_EQ(server_ip, "10.155.111.140"); ASSERT_EQ(device_id, expect_device_id); expect_device_id++; } } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_not_server_list_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_invalid_server_list_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": "0", "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_empty_server_list_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_not_server_id_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "device": [ {"device_id": "0","device_ip": "192.1.27.6","rank_id": "0"}, {"device_id": "1","device_ip": "192.2.27.6","rank_id": "1"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_invalid_server_id_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": [], "device": [ {"device_id": "0","device_ip": "192.1.27.6","rank_id": "0"}, {"device_id": "1","device_ip": "192.2.27.6","rank_id": "1"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_empty_server_id_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "", "device": [ {"device_id": "0","device_ip": "192.1.27.6","rank_id": "0"}, {"device_id": "1","device_ip": "192.2.27.6","rank_id": "1"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_not_device_failed) { 
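  // the server entry below omits the mandatory "device" array, so parsing
  // must fail with INVALID_INPUTS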
nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "10.155.111.140", "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_invalid_device_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "10.155.111.140", "device": "dsfds", "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_empty_device_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "10.155.111.140", "device": [], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_not_device_id_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "10.155.111.140", "device": [ {"device_ip": "192.1.27.6","rank_id": "0"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_invalid_device_id_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "", "device": [ {"device_id": "1wdb","device_ip": "192.1.27.6","rank_id": "0"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_not_rank_id_failed) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "10.155.111.140", "device": [ {"device_id": "0","device_ip": "192.1.27.6"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_invalid_rank_id_failed1) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "", "device": [ {"device_id": "0","device_ip": "192.1.27.6","rank_id": "0wer"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, 
test_parse_rank_table_file_with_invalid_rank_id_failed2) { nlohmann::json rank_table_server_list = R"( { "version": "1.0", "server_count": "1", "server_list": [ { "server_id": "", "device": [ {"device_id": "0","device_ip": "192.1.27.6","rank_id": "5"}], "host_nic_ip": "reserve" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithServerList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_success) { nlohmann::json rank_table_group_list = R"( { "board_id": "0x0000", "chip_info": "910", "deploy_mode": "lab", "group_count": "1", "group_list": [ { "device_num": "2", "server_num": "1", "group_name": "", "instance_count": "2", "instance_list": [ { "devices": [{"device_id": "0","device_ip": "192.1.27.6"}], "rank_id": "0", "server_id": "10.155.111.140" }, { "devices": [{"device_id": "1","device_ip": "192.2.27.6"}], "rank_id": "1", "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_group_list); ASSERT_EQ(status.StatusCode(), SUCCESS); ASSERT_EQ(servable->config_.rank_list.size(), 2); uint32_t expect_device_id = 0; for (auto &one_rank_config : servable->config_.rank_list) { std::string server_ip = one_rank_config.ip; uint32_t device_id = one_rank_config.device_id; ASSERT_EQ(server_ip, "10.155.111.140"); ASSERT_EQ(device_id, expect_device_id); expect_device_id++; } } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_not_group_list_failed) { nlohmann::json rank_table_group_list = R"( { "board_id": "0x0000", "chip_info": "910", "deploy_mode": "lab", "group_count": "1", "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_group_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_invalid_group_list_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "chip_info": "910", "group_count": "1", "group_list": "0", "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_empty_group_list_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "chip_info": "910", "group_count": "1", "group_list": [], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_not_instance_list_failed) { nlohmann::json rank_table_group_list = R"( { "board_id": "0x0000", "chip_info": "910", "deploy_mode": "lab", "group_count": "1", "group_list": [ { "server_num": "1" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_group_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_invalid_instance_list_failed) { nlohmann::json rank_table_group_list = R"( { "board_id": "0x0000", "chip_info": "910", "deploy_mode": "lab", 
"group_count": "1", "group_list": [ { "server_num": "1", "instance_list": "0" } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_group_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_empty_instance_list_failed) { nlohmann::json rank_table_group_list = R"( { "board_id": "0x0000", "chip_info": "910", "deploy_mode": "lab", "group_count": "1", "group_list": [ { "server_num": "1", "instance_list": [] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_group_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_not_server_id_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [{"device_id": "0","device_ip": "192.1.27.6"}], "rank_id": "0" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_invalid_server_id_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [{"device_id": "0","device_ip": "192.1.27.6"}], "rank_id": "0", "server_id": [] } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_empty_server_id_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [{"device_id": "0","device_ip": "192.1.27.6"}], "rank_id": "0", "server_id": "" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_not_devices_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "rank_id": "0", "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_invalid_devices_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": "rtrt", "rank_id": "0", "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_empty_devices_failed) { nlohmann::json 
rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [], "rank_id": "0", "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_not_device_id_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [{"device_ip": "192.1.27.6"}], "rank_id": "0", "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_invalid_device_id_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [{"device_id": "wd1gt2", "device_ip": "192.1.27.6"}], "rank_id": "0", "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_not_rank_id_failed) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [{"device_id": "0", "device_ip": "192.1.27.6"}], "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_invalid_rank_id_failed1) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [{"device_id": "0", "device_ip": "192.1.27.6"}], "rank_id": "tfdg5", "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } TEST_F(TestParseRankTableFile, test_parse_rank_table_file_with_group_list_invalid_rank_id_failed2) { nlohmann::json rank_table_server_list = R"( { "board_id": "0x0000", "group_list": [ { "instance_count": "1", "instance_list": [ { "devices": [{"device_id": "0", "device_ip": "192.1.27.6"}], "rank_id": "7", "server_id": "10.155.111.140" } ] } ], "status": "completed" } )"_json; auto servable = std::make_shared(); auto status = servable->ParserRankTableWithGroupList("rank_table_file", rank_table_server_list); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/cpp/tests/test_master_worker.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the 
License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "tests/ut/cpp/common/test_servable_common.h" #define private public #undef private using std::string; using std::vector; namespace mindspore { namespace serving { TEST_F(TestMasterWorkerClient, test_master_worker_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; auto y_data = InitOneInstanceRequest(&request, servable_name_, "add_common", 0); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.instances_size(), 1); ASSERT_EQ(reply.error_msg_size(), 0); auto &output_instance = reply.instances(0); ASSERT_EQ(output_instance.items_size(), 1); auto &output_items = output_instance.items(); ASSERT_EQ(output_items.begin()->first, "y"); auto &output_tensor = output_items.begin()->second; CheckTensor(output_tensor, {2, 2}, proto::MS_FLOAT32, y_data.data(), y_data.size() * sizeof(float)); } TEST_F(TestMasterWorkerClient, test_master_worker_success_version_number_1_request_version_1) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; auto y_data = InitOneInstanceRequest(&request, servable_name_, "add_common", 1); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.instances_size(), 1); ASSERT_EQ(reply.error_msg_size(), 0); auto &output_instance = reply.instances(0); ASSERT_EQ(output_instance.items_size(), 1); auto &output_items = output_instance.items(); ASSERT_EQ(output_items.begin()->first, "y"); auto &output_tensor = output_items.begin()->second; CheckTensor(output_tensor, {2, 2}, proto::MS_FLOAT32, y_data.data(), y_data.size() * sizeof(float)); } TEST_F(TestMasterWorkerClient, test_master_worker_success_version_number_2_request_version_2) { Init("test_servable_dir", "test_servable", 2, "test_add.mindir"); RegisterAddServable(); // start_servable auto status = StartServable(servable_dir_, servable_name_, 2); ASSERT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; auto y_data = InitOneInstanceRequest(&request, servable_name_, "add_common", 2); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output CheckInstanceResult(reply, y_data); } TEST_F(TestMasterWorkerClient, test_master_worker_success_version_number_2_request_lastest) { Init("test_servable_dir", "test_servable", 2, "test_add.mindir"); RegisterAddServable(); // start_servable auto status = StartServable(servable_dir_, servable_name_, 2); ASSERT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; auto y_data = InitOneInstanceRequest(&request, servable_name_, "add_common", 0); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output CheckInstanceResult(reply, y_data); } TEST_F(TestMasterWorkerClient, 
test_master_worker_success_multi_version_number_1_2_request_lastest) { auto servable_dir = std::string(test_info_->test_case_name()) + "_test_servable_dir"; Init(servable_dir, "test_servable", 1, "test_add.mindir"); Init(servable_dir, "test_servable", 2, "test_add.mindir"); RegisterAddServable(); // start_servable auto status = StartServable(servable_dir_, servable_name_, 2); ASSERT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; auto y_data = InitOneInstanceRequest(&request, servable_name_, "add_common", 0); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output CheckInstanceResult(reply, y_data); } TEST_F(TestMasterWorkerClient, test_master_worker_success_version_number_1_2_request_2) { auto servable_dir = std::string(test_info_->test_case_name()) + "_test_servable_dir"; Init(servable_dir, "test_servable", 1, "test_add.mindir"); Init(servable_dir, "test_servable", 2, "test_add.mindir"); RegisterAddServable(); // start_servable auto status = StartServable(servable_dir_, servable_name_, 2); ASSERT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; auto y_data = InitOneInstanceRequest(&request, servable_name_, "add_common", 2); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output CheckInstanceResult(reply, y_data); } TEST_F(TestMasterWorkerClient, test_master_worker_success_version_number_1_2_request_1_failed) { auto servable_dir = std::string(test_info_->test_case_name()) + "_test_servable_dir"; Init(servable_dir, "test_servable", 1, "test_add.mindir"); Init(servable_dir, "test_servable", 2, "test_add.mindir"); RegisterAddServable(); // start_servable auto status = StartServable(servable_dir_, servable_name_, 2); ASSERT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; auto y_data = InitOneInstanceRequest(&request, servable_name_, "add_common", 1); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 1); ExpectContainMsg(reply.error_msg(0).error_msg(), "Cannot find servable match servable"); } TEST_F(TestMasterWorkerClient, test_master_worker_three_instance_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // input float32 --> servable float32-float32, shape [2, 2] auto y_data_list = InitMultiInstancesRequest(&request, servable_name_, "add_common", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output CheckMultiInstanceResult(reply, y_data_list, instances_count); } TEST_F(TestMasterWorkerClient, test_master_worker_input_size_not_match_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; auto request_servable_spec = request.mutable_servable_spec(); request_servable_spec->set_name(servable_name_); request_servable_spec->set_method_name("add_common"); request_servable_spec->set_version_number(0); size_t instances_count = 3; std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data = {1.1, 2.2}; std::vector x2_data = {1.2, 2.3}; std::vector y_data; for (size_t i = 0; i < 
x1_data.size(); i++) { x1_data[i] *= (k + 1); x2_data[i] *= (k + 1); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request.add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x2 InitTensor(&input_map["x2"], {2}, proto::MS_FLOAT32, x2_data.data(), x2_data.size() * sizeof(float)); } proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), instances_count); } TEST_F(TestMasterWorkerClient, test_master_worker_with_batch_dim_true_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(true); // with_batch_dim = true // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // input float32 --> servable float32-float32, shape [2] auto y_data_list = InitMultiInstancesShape2Request(&request, servable_name_, "add_common", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output CheckMultiInstanceResult(reply, y_data_list, instances_count); } TEST_F(TestMasterWorkerClient, test_master_worker_with_batch_dim_true_input_size_not_match_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(true); // with_batch_dim = true // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // shape [2,2] not match required shape [2] as with_batch_dim = true auto y_data = InitMultiInstancesRequest(&request, servable_name_, "add_common", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), instances_count); } TEST_F(TestMasterWorkerClient, test_master_worker_error_servable_name) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid servable name auto y_data = InitMultiInstancesRequest(&request, servable_name_ + "_error", "add_common", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 1); ExpectContainMsg(reply.error_msg(0).error_msg(), "Servable test_servable_error is not declared"); } TEST_F(TestMasterWorkerClient, test_master_worker_error_method_name) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid method name auto y_data = InitMultiInstancesRequest(&request, servable_name_, "add_common_error", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 1); ExpectContainMsg(reply.error_msg(0).error_msg(), "Method add_common_error is not registered for servable test_servable"); } TEST_F(TestMasterWorkerClient, test_master_worker_error_version_number) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // 
run servable proto::PredictRequest request; size_t instances_count = 3; // invalid version_number auto y_data = InitMultiInstancesRequest(&request, servable_name_, "add_common", 2, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 1); ExpectContainMsg(reply.error_msg(0).error_msg(), "Cannot find servable match servable"); } TEST_F(TestMasterWorkerClient, test_master_worker_invalid_input_name) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid version_number auto request_servable_spec = request.mutable_servable_spec(); request_servable_spec->set_name(servable_name_); request_servable_spec->set_method_name("add_common"); request_servable_spec->set_version_number(0); std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data = {1.1, 2.2, 3.3, 4.4}; std::vector x2_data = {1.2, 2.3, 3.4, 4.5}; std::vector y_data; for (size_t i = 0; i < x1_data.size(); i++) { x1_data[i] *= (k + 1); x2_data[i] *= (k + 1); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request.add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2, 2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x3, expected is x2 InitTensor(&input_map["x3"], {2, 2}, proto::MS_FLOAT32, x2_data.data(), x2_data.size() * sizeof(float)); } proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 1); ExpectContainMsg(reply.error_msg(0).error_msg(), "Cannot find input x2 in instance input"); } TEST_F(TestMasterWorkerClient, test_master_worker_three_instance_one_input_invalid_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // input float32 --> servable float32-float32, shape [2, 2] auto y_data_list = InitMultiInstancesRequest(&request, servable_name_, "add_common", 0, instances_count); auto items = request.mutable_instances(1)->mutable_items(); auto it = items->find("x2"); ASSERT_TRUE(it != items->end()); items->erase(it); // erase x2 input proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 1); ExpectContainMsg(reply.error_msg(0).error_msg(), "Cannot find input x2 in instance input"); } TEST_F(TestMasterWorkerClient, test_master_worker_extra_input_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid version_number auto request_servable_spec = request.mutable_servable_spec(); request_servable_spec->set_name(servable_name_); request_servable_spec->set_method_name("add_common"); request_servable_spec->set_version_number(0); std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data = {1.1, 2.2, 3.3, 4.4}; std::vector x2_data = {1.2, 2.3, 3.4, 4.5}; std::vector y_data; for (size_t i = 0; i < x1_data.size(); i++) { x1_data[i] *= (k + 1); x2_data[i] 
*= (k + 1); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request.add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2, 2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x2 InitTensor(&input_map["x2"], {2, 2}, proto::MS_FLOAT32, x2_data.data(), x2_data.size() * sizeof(float)); // extra input x3 InitTensor(&input_map["x3"], {2, 2}, proto::MS_FLOAT32, x2_data.data(), x2_data.size() * sizeof(float)); } proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output CheckMultiInstanceResult(reply, y_data_list, instances_count); } TEST_F(TestMasterWorkerClient, test_master_worker_invalid_input_datatype_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid version_number auto request_servable_spec = request.mutable_servable_spec(); request_servable_spec->set_name(servable_name_); request_servable_spec->set_method_name("add_common"); request_servable_spec->set_version_number(0); std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data = {1.1, 2.2, 3.3, 4.4}; std::vector x2_data = {1.2, 2.3, 3.4, 4.5}; std::vector y_data; for (size_t i = 0; i < x1_data.size(); i++) { x1_data[i] *= (k + 1); x2_data[i] *= (k + 1); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request.add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2, 2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x2, invalid data type InitTensor(&input_map["x2"], {2, 2}, proto::MS_INT32, x2_data.data(), x2_data.size() * sizeof(float)); } proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 3); ExpectContainMsg(reply.error_msg(0).error_msg(), "Given model input 1 data type"); } TEST_F(TestMasterWorkerClient, test_master_worker_with_batch_dim_true_invalid_input_datatype_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(true); // with_batch_dim=true // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid version_number auto request_servable_spec = request.mutable_servable_spec(); request_servable_spec->set_name(servable_name_); request_servable_spec->set_method_name("add_common"); request_servable_spec->set_version_number(0); std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data = {1.1, 2.2}; std::vector x2_data = {1.2, 2.3}; std::vector y_data; for (size_t i = 0; i < x1_data.size(); i++) { x1_data[i] *= (k + 1); x2_data[i] *= (k + 1); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request.add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x2, invalid data type InitTensor(&input_map["x2"], {2}, proto::MS_INT32, x2_data.data(), x2_data.size() * sizeof(float)); } proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output 
ASSERT_EQ(reply.error_msg_size(), 3); ExpectContainMsg(reply.error_msg(0).error_msg(), "Given model input 1 data type"); } TEST_F(TestMasterWorkerClient, test_master_worker_invalid_input_datasize_not_match_shape_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid version_number auto request_servable_spec = request.mutable_servable_spec(); request_servable_spec->set_name(servable_name_); request_servable_spec->set_method_name("add_common"); request_servable_spec->set_version_number(0); std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data = {1.1, 2.2, 3.3, 4.4}; std::vector x2_data = {1.2, 2.3, 3.4, 4.5}; std::vector y_data; for (size_t i = 0; i < x1_data.size(); i++) { x1_data[i] *= (k + 1); x2_data[i] *= (k + 1); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request.add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2, 2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x2, invalid data size InitTensor(&input_map["x2"], {2, 2}, proto::MS_FLOAT32, x2_data.data(), (x2_data.size() - 1) * sizeof(float)); } proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 1); // proto parse check failed ExpectContainMsg(reply.error_msg(0).error_msg(), "Tensor check failed: input data size"); } TEST_F(TestMasterWorkerClient, test_master_worker_invalid_input_datasize_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(); // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid version_number auto request_servable_spec = request.mutable_servable_spec(); request_servable_spec->set_name(servable_name_); request_servable_spec->set_method_name("add_common"); request_servable_spec->set_version_number(0); std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data = {1.1, 2.2, 3.3, 4.4}; std::vector x2_data = {1.2, 2.3, 3.4, 4.5}; std::vector y_data; for (size_t i = 0; i < x1_data.size(); i++) { x1_data[i] *= (k + 1); x2_data[i] *= (k + 1); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request.add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2, 2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x2, invalid data size InitTensor(&input_map["x2"], {2, 1}, proto::MS_FLOAT32, x2_data.data(), (x2_data.size() - 2) * sizeof(float)); } proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 3); ExpectContainMsg(reply.error_msg(0).error_msg(), "Given model input 1 size 8"); } TEST_F(TestMasterWorkerClient, test_master_worker_with_batch_dim_true_invalid_input_datasize_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); RegisterAddServable(true); // with_batch_dim=true // start_servable StartAddServable(); // run servable proto::PredictRequest request; size_t instances_count = 3; // invalid version_number auto request_servable_spec = request.mutable_servable_spec(); 
request_servable_spec->set_name(servable_name_); request_servable_spec->set_method_name("add_common"); request_servable_spec->set_version_number(0); std::vector> y_data_list; for (size_t k = 0; k < instances_count; k++) { std::vector x1_data = {1.1, 2.2}; std::vector x2_data = {1.2, 2.3}; std::vector y_data; for (size_t i = 0; i < x1_data.size(); i++) { x1_data[i] *= (k + 1); x2_data[i] *= (k + 1); y_data.push_back(x1_data[i] + x2_data[i]); } y_data_list.push_back(y_data); auto instance = request.add_instances(); auto &input_map = (*instance->mutable_items()); // input x1 InitTensor(&input_map["x1"], {2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float)); // input x2, invalid data size InitTensor(&input_map["x2"], {1}, proto::MS_FLOAT32, x2_data.data(), (x2_data.size() - 1) * sizeof(float)); } proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // checkout output ASSERT_EQ(reply.error_msg_size(), 3); ExpectContainMsg(reply.error_msg(0).error_msg(), "Given model input 1 size 4"); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/cpp/tests/test_model_thread.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
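
// Version-selection behaviour exercised above: version_number 0 in a request resolves to the
// latest started version, while requesting a concrete version that was never started (1 when
// only 2 is running, or 2 when only 1 is running) is rejected with "Cannot find servable
// match servable". With with_batch_dim = true the per-instance shape the servable expects
// shrinks from [2, 2] to [2], and shape/dtype violations are reported per instance via error_msg.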
}  // namespace serving
}  // namespace mindspore

================================================
FILE: tests/ut/cpp/tests/test_model_thread.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common_test.h"
#include "master/server.h"
#include "common/tensor_base.h"
#define private public
#include "master/model_thread.h"
#undef private

using std::string;
using std::vector;

namespace mindspore {
namespace serving {

class TestModelThead : public UT::Common {
 public:
  TestModelThead() = default;
};

// Stub worker notifier: echoes a canned reply and completes immediately.
class MS_API TestNotify : public BaseNotifyWorker {
 public:
  explicit TestNotify(proto::PredictReply *reply) {
    if (reply) {
      reply_ = *reply;
    }
  }
  ~TestNotify() override = default;
  Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                       const PredictOnFinish &on_finish) override;
  proto::PredictReply reply_;
};

Status TestNotify::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
                                 const PredictOnFinish &on_finish) {
  *reply = reply_;
  on_finish();
  return SUCCESS;
}

std::shared_ptr<WorkerContext> InitWorkerContext(proto::PredictReply *reply = nullptr) {
  std::shared_ptr<WorkerContext> worker_context = std::make_shared<WorkerContext>();
  std::shared_ptr<TestNotify> notify = std::make_shared<TestNotify>(reply);
  WorkerRegSpec spec;
  spec.worker_pid = 1;
  spec.servable_spec.servable_name = "test_servable";
  spec.servable_spec.version_number = 1;
  spec.servable_spec.batch_size = 1;
  spec.servable_spec.methods.push_back(ServableMethodInfo{"add_cast", {}});
  worker_context->OnWorkerRegRequest(spec, notify);
  return worker_context;
}

TEST_F(TestModelThead, AddWorker) {
  ServableMethodInfo method_info;
  method_info.name = "add_cast";
  ModelThread thread("test_servable", "add_cast", 0, 1, method_info);
  uint64_t pid = 1;
  std::shared_ptr<WorkerContext> worker_context = InitWorkerContext();
  Status status = thread.AddWorker(pid, worker_context);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  // adding the same pid twice fails
  status = thread.AddWorker(pid, worker_context);
  ASSERT_EQ(status.StatusCode(), FAILED);
  pid = 2;
  status = thread.AddWorker(pid, worker_context);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
}

TEST_F(TestModelThead, DelWorker) {
  ServableMethodInfo method_info;
  method_info.name = "add_cast";
  ModelThread thread("test_servable", "add_cast", 0, 1, method_info);
  uint64_t pid = 1;
  // deleting a worker that was never added fails
  Status status = thread.DelWorker(pid);
  ASSERT_EQ(status.StatusCode(), FAILED);
  std::shared_ptr<WorkerContext> worker_context = InitWorkerContext();
  status = thread.AddWorker(pid, worker_context);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  status = thread.DelWorker(pid);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
}

TEST_F(TestModelThead, Dispatch) {
  ServableMethodInfo method_info;
  method_info.name = "add_cast";
  ModelThread thread("test_servable", "add_cast", 0, 1, method_info);
  uint64_t pid = 1;
  std::shared_ptr<WorkerContext> worker_context = InitWorkerContext();
  Status status = thread.AddWorker(pid, worker_context);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  proto::PredictRequest request;
  request.mutable_servable_spec()->set_name("test_servable");
  request.mutable_servable_spec()->set_version_number(0);
  request.mutable_servable_spec()->set_method_name("add_cast");
  proto::Instance instance;
  auto proto_instance = request.add_instances();
  *proto_instance->mutable_items() = instance.items();
  proto::PredictReply reply;
  PredictOnFinish callback = []() {};
  status = thread.DispatchAsync(request, &reply, callback);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  status = thread.DelWorker(pid);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
}

TEST_F(TestModelThead, Dispatch1) {
  ServableMethodInfo method_info;
  method_info.name = "add_cast";
  ModelThread thread("test_servable", "add_cast", 0, 1, method_info);
  uint64_t pid = 1;
  std::shared_ptr<WorkerContext> worker_context = InitWorkerContext();
  proto::PredictRequest request;
  request.mutable_servable_spec()->set_name("test_servable");
  request.mutable_servable_spec()->set_version_number(0);
  request.mutable_servable_spec()->set_method_name("add_cast");
  proto::Instance instance;
  auto proto_instance = request.add_instances();
  *proto_instance->mutable_items() = instance.items();
  proto::PredictReply reply;
  PredictOnFinish callback = []() {};
  // dispatch before any worker is added must fail
  Status status = thread.DispatchAsync(request, &reply, callback);
  ASSERT_NE(status.StatusCode(), SUCCESS);
  status = thread.AddWorker(pid, worker_context);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  status = thread.DelWorker(pid);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
}

TEST_F(TestModelThead, Commit) {
  ServableMethodInfo method_info;
  method_info.name = "add_cast";
  ModelThread thread("test_servable", "add_cast", 0, 1, method_info);
  uint64_t pid = 1;
  proto::Instance instance;
  proto::PredictReply reply;
  auto proto_instance1 = reply.add_instances();
  *proto_instance1->mutable_items() = instance.items();
  proto::ErrorMsg msg;
  auto proto_instance2 = reply.add_error_msg();
  *proto_instance2 = msg;
  std::shared_ptr<WorkerContext> worker_context = InitWorkerContext(&reply);
  Status status = thread.AddWorker(pid, worker_context);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  proto::PredictRequest request;
  auto proto_instance = request.add_instances();
  *proto_instance->mutable_items() = instance.items();
  request.mutable_servable_spec()->set_name("test_servable");
  request.mutable_servable_spec()->set_version_number(0);
  request.mutable_servable_spec()->set_method_name("add_cast");
  bool flag = false;
  PredictOnFinish callback = [&flag]() { flag = true; };
  status = thread.DispatchAsync(request, &reply, callback);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  // the stubbed worker completes synchronously, so the callback has already run
  ASSERT_EQ(flag, true);
  status = thread.DelWorker(pid);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
}
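
// ModelThread lifecycle exercised above: a worker (keyed by pid) must be registered through
// AddWorker before DispatchAsync can route a request; adding the same pid twice or deleting an
// unknown pid fails; and DispatchAsync runs the PredictOnFinish callback once the stubbed
// TestNotify worker echoes its canned reply.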
}  // namespace serving
}  // namespace mindspore

================================================
FILE: tests/ut/cpp/tests/test_parse_restful.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common_test.h"
#include "master/server.h"
#include "common/tensor_base.h"
#define private public
#include "master/restful/http_process.h"
#undef private

using std::string;
using std::vector;

namespace mindspore {
namespace serving {

class TestParseInput : public UT::Common {
 public:
  TestParseInput() = default;
};

class TestParseReply : public UT::Common {
 public:
  TestParseReply() = default;
};

// Shared helper for the request-parsing cases below: wrap a JSON document in a RestfulRequest
// and run it through RestfulService::ParseRequest, returning the status. NOTE: the
// DecomposeEvRequest/RestfulRequest template arguments were lost in extraction and are
// restored here from master/restful/restful_request.h.
static Status ParseJsRequest(const nlohmann::json &js, proto::PredictRequest *predict_request) {
  struct evhttp_request request_local = {};
  int size = 100;
  auto request_msg = std::make_shared<DecomposeEvRequest>(&request_local, size);
  request_msg->request_message_ = js;
  auto restful_request = std::make_shared<RestfulRequest>(request_msg);
  RestfulService restful_service;
  return restful_service.ParseRequest(restful_request, predict_request);
}

// Copy the bytes/string element at index 0 of a tensor into a std::string.
static std::string GetBytesAsString(const ProtoTensor &pb_tensor) {
  size_t length = 0;
  const uint8_t *ptr = nullptr;
  pb_tensor.get_bytes_data(0, &ptr, &length);
  return std::string(reinterpret_cast<const char *>(ptr), length);
}

TEST_F(TestParseInput, test_parse_SUCCESS) {
  nlohmann::json js = R"(
    {"instances":[
      { "key_tag":"scalar", "key_int": 1, "key_bool": false, "key_float": 2.3, "key_str": "ut_test",
        "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} },
      { "key_tag":"tensor", "key_int": [1,2,3], "key_bool":[[true, false], [false, true]],
        "key_float":[[1.1, 2.2]], "key_str":["ut_test"], "key_bytes":{"b64":"dXRfdGVzdA=="} },
      { "key_tag":"b64", "key_str_format1":"ut_test", "key_str_format2":{"b64":"dXRfdGVzdA==", "type":"str"},
        "key_bytes_int16":{"b64":"AQACAAIAAwADAAQA", "type":"int16", "shape":[3,2]},
        "key_bytes_fp16":{"b64":"ZjxmQJpCZkQ=", "type":"fp16", "shape":[2,2]},
        "key_bytes_bool":{"b64":"AQA=", "type":"bool", "shape":[1,2]} } ] } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  ASSERT_EQ(predict_request.instances().size(), 3);
  for (int32_t k = 0; k < predict_request.instances().size(); k++) {
    auto &cur_instance = predict_request.instances(k);
    auto &items = cur_instance.items();
    if (k == 0) {
      // scalar instance
      ASSERT_EQ(items.size(), 6);
      for (const auto &item : items) {
        ProtoTensor pb_tensor(const_cast<proto::Tensor *>(&item.second));
        if (item.first == "key_int") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Int32);
          const int32_t *data = reinterpret_cast<const int32_t *>(pb_tensor.data());
          ASSERT_EQ(*data, 1);
        } else if (item.first == "key_bool") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Bool);
          const bool *data = reinterpret_cast<const bool *>(pb_tensor.data());
          ASSERT_EQ(*data, false);
        } else if (item.first == "key_float") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Float32);
          const float *data = reinterpret_cast<const float *>(pb_tensor.data());
          ASSERT_FLOAT_EQ(*data, 2.3);
        } else if (item.first == "key_str") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_String);
          ASSERT_EQ(pb_tensor.bytes_data_size(), 1);
          ASSERT_EQ(GetBytesAsString(pb_tensor), "ut_test");
        } else if (item.first == "key_bytes") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Bytes);
          ASSERT_EQ(pb_tensor.bytes_data_size(), 1);
          ASSERT_EQ(GetBytesAsString(pb_tensor), "ut_test");
        }
      }
    } else if (k == 1) {
      // tensor instance
      ASSERT_EQ(items.size(), 6);
      for (const auto &item : items) {
        ProtoTensor pb_tensor(const_cast<proto::Tensor *>(&item.second));
        auto shape = pb_tensor.shape();
        if (item.first == "key_int") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Int32);
          ASSERT_EQ(shape.size(), 1);
          ASSERT_EQ(shape[0], 3);
          vector<int32_t> expected_value = {1, 2, 3};
          for (int i = 0; i < 3; i++) {
            const int32_t *data = reinterpret_cast<const int32_t *>(pb_tensor.data()) + i;
            ASSERT_EQ(*data, expected_value[i]);
          }
        } else if (item.first == "key_bool") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Bool);
          ASSERT_EQ(shape.size(), 2);
          ASSERT_EQ(shape[0], 2);
          ASSERT_EQ(shape[1], 2);
          vector<vector<bool>> expected_value = {{true, false}, {false, true}};
          for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
              const bool *data = reinterpret_cast<const bool *>(pb_tensor.data()) + i * 2 + j;
              ASSERT_EQ(*data, expected_value[i][j]);
            }
          }
        } else if (item.first == "key_float") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Float32);
          ASSERT_EQ(shape.size(), 2);
          ASSERT_EQ(shape[0], 1);
          ASSERT_EQ(shape[1], 2);
          vector<vector<float>> expected_value = {{1.1, 2.2}};
          for (int i = 0; i < 1; i++) {
            for (int j = 0; j < 2; j++) {
              const float *data = reinterpret_cast<const float *>(pb_tensor.data()) + i * 1 + j;
              ASSERT_FLOAT_EQ(*data, expected_value[i][j]);
            }
          }
        } else if (item.first == "key_str") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_String);
          ASSERT_EQ(shape.size(), 1);
          ASSERT_EQ(shape[0], 1);
          ASSERT_EQ(pb_tensor.bytes_data_size(), 1);
          ASSERT_EQ(GetBytesAsString(pb_tensor), "ut_test");
        } else if (item.first == "key_bytes") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Bytes);
          ASSERT_EQ(pb_tensor.bytes_data_size(), 1);
          ASSERT_EQ(GetBytesAsString(pb_tensor), "ut_test");
        }
      }
    } else if (k == 2) {
      // b64 instance
      ASSERT_EQ(items.size(), 6);
      for (const auto &item : items) {
        ProtoTensor pb_tensor(const_cast<proto::Tensor *>(&item.second));
        auto shape = pb_tensor.shape();
        if (item.first == "key_str_format1") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_String);
          ASSERT_EQ(pb_tensor.bytes_data_size(), 1);
          ASSERT_EQ(GetBytesAsString(pb_tensor), "ut_test");
        } else if (item.first == "key_str_format2") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_String);
          ASSERT_EQ(pb_tensor.bytes_data_size(), 1);
          ASSERT_EQ(GetBytesAsString(pb_tensor), "ut_test");
        } else if (item.first == "key_bytes_int16") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Int16);
          ASSERT_EQ(shape.size(), 2);
          ASSERT_EQ(shape[0], 3);
          ASSERT_EQ(shape[1], 2);
          vector<vector<int16_t>> expected_value = {{1, 2}, {2, 3}, {3, 4}};
          for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 2; j++) {
              const int16_t *data = reinterpret_cast<const int16_t *>(pb_tensor.data()) + i * 2 + j;
              ASSERT_FLOAT_EQ(*data, expected_value[i][j]);
            }
          }
        } else if (item.first == "key_bytes_fp16") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Float16);
          ASSERT_EQ(shape.size(), 2);
          ASSERT_EQ(shape[0], 2);
          ASSERT_EQ(shape[1], 2);
        } else if (item.first == "key_bytes_bool") {
          ASSERT_EQ(pb_tensor.data_type(), DataType::kMSI_Bool);
          ASSERT_EQ(shape.size(), 2);
          ASSERT_EQ(shape[0], 1);
          ASSERT_EQ(shape[1], 2);
          vector<vector<bool>> expected_value = {{true, false}};
          for (int i = 0; i < 1; i++) {
            for (int j = 0; j < 2; j++) {
              const bool *data = reinterpret_cast<const bool *>(pb_tensor.data()) + i * 2 + j;
              ASSERT_FLOAT_EQ(*data, expected_value[i][j]);
            }
          }
        }
      }
    }
  }
}

TEST_F(TestParseInput, test_instances_empty_FAIL) {
  // top-level key is empty instead of "instances"
  nlohmann::json js = R"( {"": { "key_tag":"scalar", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_instances_incorrect_FAIL) {
  // top-level key "instance" instead of "instances"
  nlohmann::json js = R"( {"instance": { "key_tag":"scalar", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_key_empty_FAIL) {
  // an instance item with an empty key is rejected
  nlohmann::json js = R"( {"instances": { "":"scalar", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_value_empty_SUCCESS) {
  // an empty string value is still a valid scalar
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_unknown_key_FAIL) {
  // b64 object carries an unknown key "type1"
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes", "type1":"bytes"} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_nob64_key_FAIL) {
  // "base64" is not accepted; the key must be "b64"
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes": {"base64": "dXRfdGVzdA==", "type": "bytes"} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_illegal_b64value_FAIL) {
  // "b64" key present so the unpadded base64 payload itself is what gets rejected
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA", "type": "bytes"} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_unknown_type_FAIL) {
  // "INt" is not a recognized data type
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "INt"} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_error_shape_format_FAIL) {
  // "shape" must be a list, not a scalar
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes_int16":{"b64":"AQACAAIAAwADAAQA", "type":"int16", "shape":3} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_error_shape_format2_FAIL) {
  // "shape" must be a flat list, not nested lists
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes_int16":{"b64":"AQACAAIAAwADAAQA", "type":"int16", "shape":[[3],[2]]} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_error_shape_value_FAIL) {
  // shape dimensions must be integers, not floats
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes_int16":{"b64":"AQACAAIAAwADAAQA", "type":"int16", "shape":[3.0,2.0]} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_error_shape_value2_FAIL) {
  // shape [3,3] = 9 elements disagrees with the decoded payload of 6 int16 values
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes_int16":{"b64":"AQACAAIAAwADAAQA", "type":"int16", "shape":[3,3]} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_obj_error_shape_value3_FAIL) {
  // negative dimensions are rejected
  nlohmann::json js = R"( {"instances": { "key_tag":"", "key_int": 1, "key_bool": false, "key_float": 2.3,
    "key_str": "ut_test", "key_bytes_int16":{"b64":"AQACAAIAAwADAAQA", "type":"int16", "shape":[3,-2]} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_tensor_value_empty_FAIL) {
  // an empty array carries no element type or shape
  nlohmann::json js = R"( {"instances": { "key_tag":"tensor", "key_int": [], "key_bool":[[true, false], [false, true]],
    "key_float":[[1.1, 2.2]], "key_str":["ut_test"], "key_bytes":{"b64":"dXRfdGVzdA=="} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}
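
// The request-parsing cases above and below exercise the three accepted encodings of an
// instance item:
//   scalar      -> a bare JSON number / boolean / string
//   tensor      -> (nested) JSON arrays with a uniform element type and rectangular shape
//   b64 object  -> {"b64": <base64 data>, "type": <dtype>, "shape": [dims]}
// Anything else is rejected: an unknown key inside the b64 object, a missing "b64" key, bad
// base64 padding, an unknown "type", a "shape" that is not a flat list of valid integer
// dimensions, or a shape whose element count disagrees with the decoded payload size.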
TEST_F(TestParseInput, test_tensor_value_diff_type_FAIL) {
  // mixed element types (int and float) in one array
  nlohmann::json js = R"( {"instances": { "key_tag":"tensor", "key_int": [1, 2.0],
    "key_bool":[[true, false], [false, true]], "key_float":[[1.1, 2.2]], "key_str":["ut_test"],
    "key_bytes":{"b64":"dXRfdGVzdA=="} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_tensor_value_diff_dimention_FAIL) {
  // ragged array: rows of different lengths
  nlohmann::json js = R"( {"instances": { "key_tag":"tensor", "key_int": [1, 2],
    "key_bool":[[true, false], [false]], "key_float":[[1.1, 2.2]], "key_str":["ut_test"],
    "key_bytes":{"b64":"dXRfdGVzdA=="} } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

TEST_F(TestParseInput, test_tensor_multi_object_FAIL) {
  // an array of b64 objects is not a valid tensor value
  nlohmann::json js = R"( {"instances": { "key_tag":"tensor", "key_int": [1, 2],
    "key_bool":[[true, false], [false, true]], "key_float":[[1.1, 2.2]], "key_str":["ut_test"],
    "key_bytes":[{"b64":"dXRfdGVzdA=="}, {"b64":"dXRfdGVzdA=="}] } } )"_json;
  proto::PredictRequest predict_request;
  Status status = ParseJsRequest(js, &predict_request);
  ASSERT_NE(status.StatusCode(), SUCCESS);
}

// Shared by test_reply_SUCCESS and test_reply_instances_num_not_match_FAIL, which build the
// same reply: one "scalar" instance and one "tensor" instance mirroring the request JSON.
static void BuildReplyInstances(proto::PredictReply *reply) {
  auto instance_ptr = reply->add_instances();
  auto &map_item = *(instance_ptr->mutable_items());
  // scalar: key_int
  proto::Tensor tensor_int;
  ProtoTensor pb_tensor_int(&tensor_int);
  DataType type_int = kMSI_Int32;
  pb_tensor_int.set_data_type(type_int);
  pb_tensor_int.set_shape({1});
  pb_tensor_int.resize_data(pb_tensor_int.GetTypeSize(type_int));
  auto data_int = reinterpret_cast<int32_t *>(pb_tensor_int.mutable_data());
  *data_int = 1;
  map_item["key_int"] = tensor_int;
  // scalar: key_bool (no shape set; serializes as a bare JSON boolean)
  proto::Tensor tensor_bool;
  ProtoTensor pb_tensor_bool(&tensor_bool);
  DataType type_bool = kMSI_Bool;
  pb_tensor_bool.set_data_type(type_bool);
  pb_tensor_bool.resize_data(pb_tensor_bool.GetTypeSize(type_bool));
  auto data_bool = reinterpret_cast<bool *>(pb_tensor_bool.mutable_data());
  *data_bool = false;
  map_item["key_bool"] = tensor_bool;
  // scalar: key_float
  proto::Tensor tensor_float;
  ProtoTensor pb_tensor_float(&tensor_float);
  DataType type_float = kMSI_Float32;
  pb_tensor_float.set_data_type(type_float);
  pb_tensor_float.set_shape({1});
  pb_tensor_float.resize_data(pb_tensor_float.GetTypeSize(type_float));
  auto data_float = reinterpret_cast<float *>(pb_tensor_float.mutable_data());
  *data_float = 2.3;
  map_item["key_float"] = tensor_float;
  // scalar: key_str
  string value = "ut_test";
  proto::Tensor tensor_str;
  ProtoTensor pb_tensor_str(&tensor_str);
  pb_tensor_str.set_data_type(kMSI_String);
  pb_tensor_str.add_bytes_data(reinterpret_cast<const uint8_t *>(value.data()), value.length());
  map_item["key_str"] = tensor_str;
  // scalar: key_bytes
  string value_bytes = "ut_test";
  proto::Tensor tensor_bytes;
  ProtoTensor pb_tensor_bytes(&tensor_bytes);
  pb_tensor_bytes.set_data_type(kMSI_Bytes);
  pb_tensor_bytes.add_bytes_data(reinterpret_cast<const uint8_t *>(value_bytes.data()), value_bytes.length());
  map_item["key_bytes"] = tensor_bytes;
  // second instance: tensors
  auto instance_ptr2 = reply->add_instances();
  auto &map_item2 = *(instance_ptr2->mutable_items());
  // tensor: key_int
  vector<int32_t> tensor_value_int = {1, 2, 3};
  proto::Tensor tensor_int2;
  ProtoTensor pb_tensor_int2(&tensor_int2);
  pb_tensor_int2.set_data_type(kMSI_Int32);
  pb_tensor_int2.set_shape({3});
  pb_tensor_int2.resize_data(pb_tensor_int2.GetTypeSize(kMSI_Int32) * 3);
  for (int i = 0; i < 3; i++) {
    auto data_int2 = reinterpret_cast<int32_t *>(pb_tensor_int2.mutable_data()) + i;
    *data_int2 = tensor_value_int[i];
  }
  map_item2["key_int"] = tensor_int2;
  // tensor: key_bool
  vector<vector<bool>> tensor_value_bool = {{true, false}, {false, true}};
  proto::Tensor tensor_bool2;
  ProtoTensor pb_tensor_bool2(&tensor_bool2);
  pb_tensor_bool2.set_data_type(kMSI_Bool);
  pb_tensor_bool2.set_shape({2, 2});
  pb_tensor_bool2.resize_data(pb_tensor_bool2.GetTypeSize(kMSI_Bool) * 4);
  for (int i = 0; i < 2; i++) {
    for (int j = 0; j < 2; j++) {
      auto data_bool2 = reinterpret_cast<bool *>(pb_tensor_bool2.mutable_data()) + i * 2 + j;
      *data_bool2 = tensor_value_bool[i][j];
    }
  }
  map_item2["key_bool"] = tensor_bool2;
  // tensor: key_float
  vector<vector<float>> tensor_value_float = {{1.1, 2.2}};
  proto::Tensor tensor_float2;
  ProtoTensor pb_tensor_float2(&tensor_float2);
  pb_tensor_float2.set_data_type(kMSI_Float32);
  pb_tensor_float2.set_shape({1, 2});
  pb_tensor_float2.resize_data(pb_tensor_float2.GetTypeSize(kMSI_Float32) * 2);
  for (int i = 0; i < 1; i++) {
    for (int j = 0; j < 2; j++) {
      auto data_float2 = reinterpret_cast<float *>(pb_tensor_float2.mutable_data()) + i * 1 + j;
      *data_float2 = tensor_value_float[i][j];
    }
  }
  map_item2["key_float"] = tensor_float2;
  // tensor: key_str
  vector<string> tensor_value_str = {"ut_test", "ut_test2"};
  proto::Tensor tensor_str2;
  ProtoTensor pb_tensor_str2(&tensor_str2);
  pb_tensor_str2.set_data_type(kMSI_String);
  pb_tensor_str2.set_shape({2});
  for (int i = 0; i < 2; i++) {
    pb_tensor_str2.add_bytes_data(reinterpret_cast<const uint8_t *>(tensor_value_str[i].data()),
                                  tensor_value_str[i].length());
  }
  map_item2["key_str"] = tensor_str2;
}

TEST_F(TestParseReply, test_reply_SUCCESS) {
  nlohmann::json js = R"(
    {"instances":[
      { "key_tag":"scalar", "key_int": 1, "key_bool": false, "key_float": 2.3, "key_str": "ut_test",
        "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} },
      { "key_tag":"tensor", "key_int": [1,2,3], "key_bool":[[true, false], [false, true]],
        "key_float":[[1.1, 2.2]], "key_str":["ut_test"] } ] } )"_json;
  struct evhttp_request request_local = {};
  struct evhttp_request *request = &request_local;
  int size = 100;
  std::shared_ptr<DecomposeEvRequest> request_msg = std::make_shared<DecomposeEvRequest>(request, size);
  request_msg->request_message_ = js;
  std::shared_ptr<RestfulRequest> restful_request = std::make_shared<RestfulRequest>(request_msg);
  proto::PredictRequest predict_request;
  RestfulService restful_service;
  Status status(INVALID_INPUTS);
  status = restful_service.ParseRequest(restful_request, &predict_request);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  nlohmann::json out_js;
  proto::PredictReply reply;
  BuildReplyInstances(&reply);
  Status status2 = restful_service.ParseReply(reply, &out_js);
  ASSERT_EQ(status2.StatusCode(), SUCCESS);
  string out_str = out_js.dump();
  std::cout << "Parse reply out:" << out_str << std::endl;
  ASSERT_TRUE(out_js.is_object());
  for (auto &item : out_js.items()) {
    ASSERT_EQ(item.key(), "instances");
    ASSERT_TRUE(item.value().is_array());
    ASSERT_EQ(item.value().size(), 2);
    int sum = 0;  // array
    for (auto &element : item.value()) {
      ASSERT_TRUE(element.is_object());
      if (element.size() == 5) {
        int count = 0;  // object
        std::cout << "===start====" << std::endl;
        for (auto &it : element.items()) {
          if (it.key() == "key_int") {
            ASSERT_TRUE(it.value().is_array());
            ASSERT_EQ(it.value().size(), 1);
            auto array_items = it.value().items();
            auto int_val = *(array_items.begin());
            ASSERT_TRUE(int_val.value().is_number_integer());
            ASSERT_EQ(int_val.value().get<int32_t>(), 1);
            count++;
          } else if (it.key() == "key_bool") {
            ASSERT_TRUE(it.value().is_boolean());
            ASSERT_EQ(it.value().get<bool>(), false);
            count++;
          } else if (it.key() == "key_float") {
            ASSERT_TRUE(it.value().is_array());
            ASSERT_EQ(it.value().size(), 1);
            auto array_items = it.value().items();
            auto float_val = *(array_items.begin());
            ASSERT_FLOAT_EQ(float_val.value().get<float>(), 2.3);
            count++;
          } else if (it.key() == "key_str") {
            ASSERT_TRUE(it.value().is_string());
            ASSERT_EQ(it.value().get<string>(), "ut_test");
            count++;
          } else if (it.key() == "key_bytes") {
            ASSERT_TRUE(it.value().is_object());
            ASSERT_EQ(it.value()["b64"].get<string>(), "dXRfdGVzdA==");
            count++;
          }
        }
        ASSERT_EQ(count, 5);
        sum++;
      } else if (element.size() == 4) {
        int count = 0;  // object
        for (auto &it : element.items()) {
          if (it.key() == "key_int") {
            ASSERT_TRUE(it.value().is_array());
            ASSERT_EQ(it.value().size(), 3);
            ASSERT_EQ(it.value()[0].get<int32_t>(), 1);
            ASSERT_EQ(it.value()[1].get<int32_t>(), 2);
            ASSERT_EQ(it.value()[2].get<int32_t>(), 3);
            count++;
          } else if (it.key() == "key_bool") {
            ASSERT_TRUE(it.value().is_array());
            ASSERT_EQ(it.value().size(), 2);
            ASSERT_TRUE(it.value()[0].is_array());
            ASSERT_EQ(it.value()[0].size(), 2);
            ASSERT_EQ(it.value()[0][0].get<bool>(), true);
            ASSERT_EQ(it.value()[0][1].get<bool>(), false);
            ASSERT_EQ(it.value()[1].size(), 2);
            ASSERT_EQ(it.value()[1][0].get<bool>(), false);
            ASSERT_EQ(it.value()[1][1].get<bool>(), true);
            count++;
          } else if (it.key() == "key_float") {
            ASSERT_TRUE(it.value().is_array());
            ASSERT_EQ(it.value().size(), 1);
            ASSERT_TRUE(it.value()[0].is_array());
            ASSERT_EQ(it.value()[0].size(), 2);
            ASSERT_FLOAT_EQ(it.value()[0][0].get<float>(), 1.1);
            ASSERT_FLOAT_EQ(it.value()[0][1].get<float>(), 2.2);
            count++;
          } else if (it.key() == "key_str") {
            ASSERT_TRUE(it.value().is_array());
            ASSERT_EQ(it.value().size(), 2);
            ASSERT_EQ(it.value()[0].get<string>(), "ut_test");
            ASSERT_EQ(it.value()[1].get<string>(), "ut_test2");
            count++;
          }
        }
        ASSERT_EQ(count, 4);
        sum++;
      }
    }
    ASSERT_EQ(sum, 2);
  }
}

TEST_F(TestParseReply, test_reply_instances_num_not_match_FAIL) {
  // the request carries one instance ...
  nlohmann::json js = R"( {"instances":[ { "key_tag":"scalar", "key_int": 1, "key_bool": false,
    "key_float": 2.3, "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} } ] } )"_json;
  struct evhttp_request request_local = {};
  struct evhttp_request *request = &request_local;
  int size = 100;
  std::shared_ptr<DecomposeEvRequest> request_msg = std::make_shared<DecomposeEvRequest>(request, size);
  request_msg->request_message_ = js;
  std::shared_ptr<RestfulRequest> restful_request = std::make_shared<RestfulRequest>(request_msg);
  proto::PredictRequest predict_request;
  RestfulService restful_service;
  Status status(INVALID_INPUTS);
  status = restful_service.ParseRequest(restful_request, &predict_request);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  nlohmann::json out_js;
  proto::PredictReply reply;
  // ... but the reply carries two instances, so ParseReply must fail
  BuildReplyInstances(&reply);
  Status status2 = restful_service.ParseReply(reply, &out_js);
  ASSERT_NE(status2.StatusCode(), SUCCESS);
}

TEST_F(TestParseReply, test_reply_error_num_not_match_FAIL) {
  nlohmann::json js = R"( {"instances":[ { "key_tag":"scalar", "key_int": 1, "key_bool": false,
    "key_float": 2.3, "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} } ] } )"_json;
  struct evhttp_request request_local = {};
  struct evhttp_request *request = &request_local;
  int size = 100;
  std::shared_ptr<DecomposeEvRequest> request_msg = std::make_shared<DecomposeEvRequest>(request, size);
  request_msg->request_message_ = js;
  std::shared_ptr<RestfulRequest> restful_request = std::make_shared<RestfulRequest>(request_msg);
  proto::PredictRequest predict_request;
  RestfulService restful_service;
  Status status(INVALID_INPUTS);
  status = restful_service.ParseRequest(restful_request, &predict_request);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  nlohmann::json out_js;
  proto::PredictReply reply;
  // two error messages for a one-instance request
  auto error_msg = reply.add_error_msg();
  error_msg->set_error_msg("error1");
  auto error_msg2 = reply.add_error_msg();
  error_msg2->set_error_msg("error2");
  Status status2 = restful_service.ParseReply(reply, &out_js);
  ASSERT_NE(status2.StatusCode(), SUCCESS);
}

TEST_F(TestParseReply, test_reply_type_not_set_FAIL) {
  nlohmann::json js = R"( {"instances":[ { "key_tag":"scalar", "key_int": 1, "key_bool": false,
    "key_float": 2.3, "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} } ] } )"_json;
  struct evhttp_request request_local = {};
  struct evhttp_request *request = &request_local;
  int size = 100;
  std::shared_ptr<DecomposeEvRequest> request_msg = std::make_shared<DecomposeEvRequest>(request, size);
  request_msg->request_message_ = js;
  std::shared_ptr<RestfulRequest> restful_request = std::make_shared<RestfulRequest>(request_msg);
  proto::PredictRequest predict_request;
  RestfulService restful_service;
  Status status(INVALID_INPUTS);
  status = restful_service.ParseRequest(restful_request, &predict_request);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  nlohmann::json out_js;
  proto::PredictReply reply;
  auto instance_ptr = reply.add_instances();
  auto &map_item = *(instance_ptr->mutable_items());
  // scalar key_int whose data type is never set
  proto::Tensor tensor_int;
  ProtoTensor pb_tensor_int(&tensor_int);
  pb_tensor_int.set_shape({1});
  pb_tensor_int.resize_data(pb_tensor_int.GetTypeSize(kMSI_Int32));
  auto data_int = reinterpret_cast<int32_t *>(pb_tensor_int.mutable_data());
  *data_int = 1;
  map_item["key_int"] = tensor_int;
  Status status2 = restful_service.ParseReply(reply, &out_js);
  ASSERT_NE(status2.StatusCode(), SUCCESS);
}

TEST_F(TestParseReply, test_reply_type_fp16_SUCCESS) {
  nlohmann::json js = R"( {"instances":[ { "key_tag":"scalar", "key_int": 1, "key_bool": false,
    "key_float": 2.3, "key_str": "ut_test", "key_bytes": {"b64": "dXRfdGVzdA==", "type": "bytes"} } ] } )"_json;
  struct evhttp_request request_local = {};
  struct evhttp_request *request = &request_local;
  int size = 100;
  std::shared_ptr<DecomposeEvRequest> request_msg = std::make_shared<DecomposeEvRequest>(request, size);
  request_msg->request_message_ = js;
  std::shared_ptr<RestfulRequest> restful_request = std::make_shared<RestfulRequest>(request_msg);
  proto::PredictRequest predict_request;
  RestfulService restful_service;
  Status status(INVALID_INPUTS);
  status = restful_service.ParseRequest(restful_request, &predict_request);
  ASSERT_EQ(status.StatusCode(), SUCCESS);
  nlohmann::json out_js;
  proto::PredictReply reply;
  auto instance_ptr = reply.add_instances();
  auto &map_item = *(instance_ptr->mutable_items());
  // scalar key_float16: float16 replies are supported
  proto::Tensor tensor_float;
  ProtoTensor pb_tensor_float(&tensor_float);
  DataType type_float = kMSI_Float16;
  pb_tensor_float.set_data_type(type_float);
  pb_tensor_float.set_shape({1});
  pb_tensor_float.resize_data(pb_tensor_float.GetTypeSize(type_float));
  map_item["key_float16"] = tensor_float;
  Status status2 = restful_service.ParseReply(reply, &out_js);
  ASSERT_EQ(status2.StatusCode(), SUCCESS);
}
reply.add_instances(); auto &map_item = *(instance_ptr->mutable_items()); // test scalar: // scalar: key_float proto::Tensor tensor_float; ProtoTensor pb_tensor_float(&tensor_float); DataType type_float = kMSI_Float16; pb_tensor_float.set_data_type(type_float); pb_tensor_float.set_shape({1}); pb_tensor_float.resize_data(pb_tensor_float.GetTypeSize(type_float)); map_item["key_float16"] = tensor_float; Status status2 = restful_service.ParseReply(reply, &out_js); ASSERT_EQ(status2.StatusCode(), SUCCESS); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/cpp/tests/test_shared_memory.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "tests/ut/cpp/common/test_servable_common.h" #include "common/shared_memory.h" #define private public #undef private using std::string; using std::vector; namespace mindspore { namespace serving { class TestSharedMemory : public UT::Common { public: void SetUp() override { UT::Common::SetUp(); } void TearDown() override { UT::Common::TearDown(); } }; TEST_F(TestSharedMemory, test_alloc_release_shared_memory_success) { SharedMemoryAllocator allocator; std::string memory_key_prefix = "test_memory_key"; uint64_t item_size = 64; auto status = allocator.NewMemoryBuffer(memory_key_prefix, item_size, 3); ASSERT_TRUE(status == SUCCESS); std::string first_memory_key; std::vector first_shm_list; for (int i = 0; i < 3; i++) { SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); ASSERT_EQ(shm_item.memory_key_prefix, memory_key_prefix); ASSERT_EQ(shm_item.size, item_size); ASSERT_TRUE(shm_item.memory_key.find(memory_key_prefix) != std::string::npos); if (first_memory_key.empty()) { first_memory_key = shm_item.memory_key; } else { ASSERT_EQ(first_memory_key, shm_item.memory_key); } first_shm_list.push_back(shm_item); } // new shared memory std::string second_memory_key; std::vector second_shm_list; for (int i = 0; i < 3; i++) { SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); ASSERT_EQ(shm_item.memory_key_prefix, memory_key_prefix); ASSERT_EQ(shm_item.size, item_size); ASSERT_TRUE(shm_item.memory_key.find(memory_key_prefix) != std::string::npos); if (second_memory_key.empty()) { second_memory_key = shm_item.memory_key; } else { ASSERT_EQ(second_memory_key, shm_item.memory_key); } ASSERT_NE(second_memory_key, first_memory_key); second_shm_list.push_back(shm_item); } // free shared memory and alloc { auto &free_memory = second_shm_list[1]; allocator.ReleaseMemoryItem(free_memory); SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); ASSERT_EQ(shm_item.memory_key, free_memory.memory_key); ASSERT_EQ(shm_item.bytes_size, free_memory.bytes_size); ASSERT_EQ(shm_item.offset_address, free_memory.offset_address); 
ASSERT_EQ(shm_item.offset, free_memory.offset); } { auto &free_memory = first_shm_list[1]; allocator.ReleaseMemoryItem(free_memory); SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); ASSERT_EQ(shm_item.memory_key, free_memory.memory_key); ASSERT_EQ(shm_item.bytes_size, free_memory.bytes_size); ASSERT_EQ(shm_item.offset_address, free_memory.offset_address); ASSERT_EQ(shm_item.offset, free_memory.offset); } } TEST_F(TestSharedMemory, test_alloc_release_shared_memory_repeat_release_failed) { SharedMemoryAllocator allocator; std::string memory_key_prefix = "test_memory_key"; uint64_t item_size = 64; auto status = allocator.NewMemoryBuffer(memory_key_prefix, item_size, 3); ASSERT_TRUE(status == SUCCESS); SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); allocator.ReleaseMemoryItem(shm_item); try { allocator.ReleaseMemoryItem(shm_item); FAIL(); } catch (std::runtime_error &ex) { std::string error_msg = ex.what(); auto index = error_msg.find("Shared memory " + shm_item.memory_key + " has already been in free set, offset: "); ASSERT_TRUE(index != std::string::npos); } } TEST_F(TestSharedMemory, test_alloc_attach_shared_memory_success) { SharedMemoryAllocator allocator; std::string memory_key_prefix = "test_memory_key"; uint64_t item_size = 64; auto status = allocator.NewMemoryBuffer(memory_key_prefix, item_size, 3); ASSERT_TRUE(status == SUCCESS); SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryManager attach; SharedMemoryAttachItem attach_item; status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status == SUCCESS); ASSERT_NE(shm_item.offset_address, attach_item.offset_address); attach_item.offset_address[0] = 0xfe; ASSERT_EQ(0xfe, shm_item.offset_address[0]); shm_item.offset_address[1] = 0xfa; ASSERT_EQ(0xfa, attach_item.offset_address[1]); attach.Detach(attach_item.memory_key); } TEST_F(TestSharedMemory, test_alloc_twice_attach_shared_memory_success) { SharedMemoryAllocator allocator; std::string memory_key_prefix = "test_memory_key"; uint64_t item_size = 64; auto status = allocator.NewMemoryBuffer(memory_key_prefix, item_size, 3); ASSERT_TRUE(status == SUCCESS); SharedMemoryManager attach; std::string memory_key; // first memory item { SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryAttachItem attach_item; status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status == SUCCESS); ASSERT_NE(shm_item.offset_address, attach_item.offset_address); attach_item.offset_address[0] = 0xfe; ASSERT_EQ(0xfe, shm_item.offset_address[0]); shm_item.offset_address[1] = 0xfa; ASSERT_EQ(0xfa, attach_item.offset_address[1]); memory_key = shm_item.memory_key; } // second memory item { SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryAttachItem attach_item; status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status == SUCCESS); ASSERT_NE(shm_item.offset_address, attach_item.offset_address); attach_item.offset_address[3] = 0xfe; ASSERT_EQ(0xfe, shm_item.offset_address[3]); shm_item.offset_address[4] 
= 0xfa; ASSERT_EQ(0xfa, attach_item.offset_address[4]); } attach.Detach(memory_key); } TEST_F(TestSharedMemory, test_alloc_re_attach_shared_memory_success) { SharedMemoryAllocator allocator; std::string memory_key_prefix = "test_memory_key"; uint64_t item_size = 64; auto status = allocator.NewMemoryBuffer(memory_key_prefix, item_size, 3); ASSERT_TRUE(status == SUCCESS); SharedMemoryManager attach; // first memory item { SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryAttachItem attach_item; status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status == SUCCESS); ASSERT_NE(shm_item.offset_address, attach_item.offset_address); attach_item.offset_address[0] = 0xfe; ASSERT_EQ(0xfe, shm_item.offset_address[0]); shm_item.offset_address[1] = 0xfa; ASSERT_EQ(0xfa, attach_item.offset_address[1]); attach.Detach(shm_item.memory_key); } // second memory item { SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryAttachItem attach_item; status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status == SUCCESS); ASSERT_NE(shm_item.offset_address, attach_item.offset_address); attach_item.offset_address[3] = 0xfe; ASSERT_EQ(0xfe, shm_item.offset_address[3]); shm_item.offset_address[4] = 0xfa; ASSERT_EQ(0xfa, attach_item.offset_address[4]); attach.Detach(shm_item.memory_key); } } TEST_F(TestSharedMemory, test_alloc_attach_shared_memory_attach_repeat_success) { SharedMemoryAllocator allocator; std::string memory_key_prefix = "test_memory_key"; uint64_t item_size = 64; auto status = allocator.NewMemoryBuffer(memory_key_prefix, item_size, 3); ASSERT_TRUE(status == SUCCESS); SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryManager attach; SharedMemoryAttachItem attach_item; status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryAttachItem attach_item2; status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item2); ASSERT_TRUE(status == SUCCESS); ASSERT_EQ(attach_item.offset_address, attach_item2.offset_address); } TEST_F(TestSharedMemory, test_alloc_attach_shared_memory_detach_repeat_failed) { SharedMemoryAllocator allocator; std::string memory_key_prefix = "test_memory_key"; uint64_t item_size = 64; auto status = allocator.NewMemoryBuffer(memory_key_prefix, item_size, 3); ASSERT_TRUE(status == SUCCESS); SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryManager attach; SharedMemoryAttachItem attach_item; status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status == SUCCESS); status = attach.Detach(shm_item.memory_key); ASSERT_TRUE(status == SUCCESS); status = attach.Detach(shm_item.memory_key); ASSERT_TRUE(status != SUCCESS); } TEST_F(TestSharedMemory, test_alloc_attach_invalid_shared_memory_failed) { SharedMemoryAllocator allocator; std::string memory_key_prefix = "test_memory_key"; uint64_t item_size = 64; auto status = allocator.NewMemoryBuffer(memory_key_prefix, item_size, 1); ASSERT_TRUE(status == SUCCESS); 
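// Each invalid Attach() below corrupts exactly one argument (memory key, total byte size, data offset, or data
// size) while keeping the others valid, so each failing status isolates a single validation path; the final
// Attach() with all-valid arguments confirms the allocated item itself is attachable.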
SharedMemoryItem shm_item; status = allocator.AllocMemoryItem(memory_key_prefix, &shm_item); ASSERT_TRUE(status == SUCCESS); SharedMemoryManager attach; SharedMemoryAttachItem attach_item; // invalid memory key status = attach.Attach("invalid memory key", shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status != SUCCESS); // invalid memory bytes size status = attach.Attach(shm_item.memory_key, 0, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status != SUCCESS); // invalid memory data offset status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.bytes_size, shm_item.size, &attach_item); ASSERT_TRUE(status != SUCCESS); // invalid memory data size status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, 0, shm_item.bytes_size + 1, &attach_item); ASSERT_TRUE(status != SUCCESS); // success status = attach.Attach(shm_item.memory_key, shm_item.bytes_size, shm_item.offset, shm_item.size, &attach_item); ASSERT_TRUE(status == SUCCESS); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/cpp/tests/test_start_preprocess_postprocess.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "tests/ut/cpp/common/test_servable_common.h" namespace mindspore { namespace serving { class TestPreprocessPostprocess : public TestMasterWorkerClient { public: TestPreprocessPostprocess() = default; ~TestPreprocessPostprocess() = default; virtual void SetUp() {} virtual void TearDown() { TestMasterWorkerClient::TearDown(); } MethodSignature InitDefaultMethod() { MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); return method_signature; } MethodSignature InitMethodSig() { MethodSignature method_signature; method_signature.servable_name = "test_servable"; method_signature.method_name = "add_cast"; method_signature.inputs = {"x1", "x2"}; method_signature.outputs = {"y"}; return method_signature; } const std::string model_file_ = "test_add.mindir"; }; TEST_F(TestPreprocessPostprocess, test_master_worker_with_preproces_and_postprocess_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); // declare_servable DeclareServable("test_servable", "test_add.mindir", "mindir", false); // register method ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitDefaultMethod(); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; size_t instances_count = 3; // input int32 --> preprocess int32-float32 --> servable float32-float32 --> postprocess float32-int32, shape [2,2] auto y_data_list = InitMultiInstancesRequest(&request, servable_name_, "add_cast", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // check output CheckMultiInstanceResult(reply, y_data_list, instances_count); } TEST_F(TestPreprocessPostprocess, test_master_worker_with_preproces_and_postprocess_batching_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); // declare_servable // with_batch_dim = true DeclareServable("test_servable", "test_add.mindir", "mindir", true); // register method ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitDefaultMethod(); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; size_t instances_count = 3; // input int32 --> preprocess int32-float32 --> servable float32-float32 --> postprocess float32-int32, shape [2] auto y_data_list = InitMultiInstancesShape2Request(&request, servable_name_, "add_cast", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // check output CheckMultiInstanceResult(reply, y_data_list, instances_count); } TEST_F(TestPreprocessPostprocess, test_master_worker_with_only_preproces_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); // declare_servable DeclareServable("test_servable",
"test_add.mindir", "mindir", false); // register method ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // servable output as method output method_signature.SetReturn({{2, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; size_t instances_count = 3; // input int32 --> preprocess int32-float32 --> servable float32-float32, shape [2,2] auto y_data_list = InitMultiInstancesRequest(&request, servable_name_, "add_cast", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // check output CheckMultiInstanceResult(reply, y_data_list, instances_count); } TEST_F(TestPreprocessPostprocess, test_master_worker_with_only_preproces_batching_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); // declare_servable // with_batch_dim=true DeclareServable("test_servable", "test_add.mindir", "mindir", true); // register method ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); method_signature.SetReturn({{2, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; size_t instances_count = 3; // input int32 --> preprocess int32-float32 --> servable float32-float32, shape [2] auto y_data_list = InitMultiInstancesShape2Request(&request, servable_name_, "add_cast", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); ASSERT_TRUE(grpc_status.ok()); // check output CheckMultiInstanceResult(reply, y_data_list, instances_count); } TEST_F(TestPreprocessPostprocess, test_master_worker_with_only_postprocess_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); // declare_servable DeclareServable("test_servable", "test_add.mindir", "mindir", false); // register method ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{0, 0}, {0, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{1, 0}}); // servable output as method output method_signature.SetReturn({{2, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; size_t instances_count = 3; // input float32 --> servable float32-float32 --> postprocess float32-int32, shape [2,2] auto y_data_list = InitMultiInstancesRequest(&request, servable_name_, "add_cast",
0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // check output CheckMultiInstanceResult(reply, y_data_list, instances_count); } TEST_F(TestPreprocessPostprocess, test_master_worker_with_only_postprocess_batching_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); // declare_servable // with_batch_dim=true DeclareServable("test_servable", "test_add.mindir", "mindir", true); // register method ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{0, 0}, {0, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{1, 0}}); // servable output as method output method_signature.SetReturn({{2, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; size_t instances_count = 3; // input float32 --> servable float32-float32 --> postprocess float32-int32, shape [2] auto y_data_list = InitMultiInstancesShape2Request(&request, servable_name_, "add_cast", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // check output CheckMultiInstanceResult(reply, y_data_list, instances_count); } // Test data flow in input\preprocess\predict\postprocess TEST_F(TestPreprocessPostprocess, test_worker_start_preprocess_not_found) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); try { MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("preprocess_fake_fun", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); FAIL(); } catch (std::runtime_error &ex) { ExpectContainMsg(ex.what(), "Function 'preprocess_fake_fun' is not defined") } } TEST_F(TestPreprocessPostprocess, test_worker_start_postprocess_not_found) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); try { MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("postprocess_fake_fun", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); FAIL(); } catch (std::runtime_error &ex) { ExpectContainMsg(ex.what(), "Function 'postprocess_fake_fun' is not defined") } } TEST_F(TestPreprocessPostprocess, test_preproces_process_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false);
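// Runtime-failure case: registration and startup succeed, but the request data type does not match what the
// preprocess stage expects, so the failure surfaces per instance (one error_msg per instance in the reply)
// rather than as a startup error.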
ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitDefaultMethod(); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; size_t instances_count = 3; // input float32, invalid for preprocess, which requires int32 auto y_data_list = InitMultiInstancesRequest(&request, servable_name_, "add_cast", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // check output ASSERT_EQ(reply.error_msg_size(), instances_count); ExpectContainMsg(reply.error_msg(0).error_msg(), "Call failed: Input data type invalid"); } TEST_F(TestPreprocessPostprocess, test_postproces_process_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{0, 0}}); // use method input as postprocess input // servable output as method output method_signature.SetReturn({{2, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); // run servable proto::PredictRequest request; size_t instances_count = 3; // input int32, invalid for postprocess auto y_data_list = InitMultiInstancesRequest(&request, servable_name_, "add_cast", 0, instances_count); proto::PredictReply reply; auto grpc_status = Dispatch(request, &reply); EXPECT_TRUE(grpc_status.ok()); // check output ASSERT_EQ(reply.error_msg_size(), instances_count); ExpectContainMsg(reply.error_msg(0).error_msg(), "Postprocess failed: Input data type invalid"); } TEST_F(TestPreprocessPostprocess, test_preproces_input_invalid1_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{1, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The 0th input data of stage 1 cannot not come from stage 1"); } TEST_F(TestPreprocessPostprocess, test_preproces_input_invalid2_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false);
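// A reading aid for the {{stage, index}} pairs in this test family: stage 0 denotes the method's own inputs,
// stage 1 the preprocess function, stage 2 the model and stage 3 the postprocess function; each pair selects
// output `index` of `stage`, and a stage may only consume outputs of strictly earlier stages. A valid chain
// therefore looks like: AddStageFunction(pre, {{0, 0}, {0, 1}}); AddStageModel(m, {{1, 0}, {1, 1}});
// AddStageFunction(post, {{2, 0}}); SetReturn({{3, 0}}); -- each *_input_invalid* test breaks this rule
// (or indexes past a stage's output count) and asserts the corresponding registration error.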
ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {2, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The 1th input data of stage 1 cannot not come from stage 2"); } TEST_F(TestPreprocessPostprocess, test_preproces_input_invalid3_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {3, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The 1th input data of stage 1 cannot not come from stage 3"); } TEST_F(TestPreprocessPostprocess, test_preproces_input_invalid4_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 2}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage 1 1th input uses method 2th input, that is greater than the method inputs size 2"); } TEST_F(TestPreprocessPostprocess, test_predict_input_invalid1_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{2, 0}, {1, 1}}); // postprocess 
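// (the postprocess wiring below is valid; the injected fault is the model stage consuming its own
// stage-2 output via {2, 0} above)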
method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The 0th input data of stage 2 cannot not come from stage 2"); } TEST_F(TestPreprocessPostprocess, test_predict_input_invalid2_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {3, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The 1th input data of stage 2 cannot not come from stage 3"); } TEST_F(TestPreprocessPostprocess, test_predict_input_invalid3_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 2}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage(begin with 1) 2 1th input uses c++ function stub_preprocess_cast_int32_to_fp32_cpp " "2th output, that is greater than the function output size 2"); } TEST_F(TestPreprocessPostprocess, test_predict_input_invalid4_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{0, 2}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); 
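// {0, 2} requests the method's third input, but the signature only declares two inputs (x1, x2):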
ExpectContainMsg(status.StatusMessage(), "The stage 2 0th input uses method 2th input, that is greater than the method inputs size 2"); } TEST_F(TestPreprocessPostprocess, test_postprocess_input_invalid1_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{3, 0}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The 0th input data of stage 3 cannot not come from stage 3"); } TEST_F(TestPreprocessPostprocess, test_postprocess_input_invalid2_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{0, 2}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage 3 0th input uses method 2th input, that is greater than the method inputs size 2"); } TEST_F(TestPreprocessPostprocess, test_postprocess_input_invalid3_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{1, 2}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage(begin with 1) 3 0th input uses c++ function stub_preprocess_cast_int32_to_fp32_cpp" " 2th output, that is greater than the function output size 2"); } TEST_F(TestPreprocessPostprocess, test_postprocess_input_invalid4_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", 
"test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 1}}); // servable output as method output method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage(begin with 1) 3 0th input uses model " "test_add.mindir subgraph 0 1th output, that is greater than the model output size 1"); } TEST_F(TestPreprocessPostprocess, test_return_invalid1_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{0, 2}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage 4 0th input uses method 2th input, " "that is greater than the method inputs size 2"); } TEST_F(TestPreprocessPostprocess, test_return_invalid2_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{1, 2}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage(begin with 1) 4 0th input uses c++ function stub_preprocess_cast_int32_to_fp32_cpp" " 2th output, that is greater than the function output size 2"); } TEST_F(TestPreprocessPostprocess, test_return_invalid3_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // 
method input 0 and input 1 as servable input method_signature.AddStageModel(model_file_, {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{2, 1}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage(begin with 1) 4 0th input uses model " "test_add.mindir subgraph 0 1th output, that is greater than the model output size 1"); } TEST_F(TestPreprocessPostprocess, test_return_invalid4_failed) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", false); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature = InitMethodSig(); // preprocess method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 1}}); // method input 0 and input 1 as servable input method_signature.AddStageModel("test_add.mindir", {{1, 0}, {1, 1}}); // postprocess method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // servable output as method output method_signature.SetReturn({{3, 1}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The stage(begin with 1) 4 0th input uses c++ function stub_postprocess_cast_fp32_to_int32_cpp" " 1th output, that is greater than the function output size 1"); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/cpp/tests/test_start_worker.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "tests/ut/cpp/common/test_servable_common.h" namespace mindspore { namespace serving { class TestStartWorker : public TestMasterWorker { public: TestStartWorker() = default; ~TestStartWorker() = default; virtual void SetUp() {} virtual void TearDown() { TestMasterWorker::TearDown(); } }; TEST_F(TestStartWorker, test_worker_start_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); } TEST_F(TestStartWorker, test_worker_start_error_model_file_name) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add_error.mindir", "mindir", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable auto status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Load model failed, servable directory: "); } TEST_F(TestStartWorker, test_worker_start_error_version_number) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable int error_version_number = 2; auto status = StartServable("test_servable_dir", "test_servable", error_version_number); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg( status.StatusMessage(), "Start servable failed: There is no specified version directory of models, specified version number: 2"); } TEST_F(TestStartWorker, test_worker_start_multi_version_number) { auto servable_dir = std::string(test_info_->test_case_name()) + "_test_servable_dir"; Init(servable_dir, "test_servable", 1, "test_add.mindir"); Init(servable_dir, "test_servable", 2, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable int version_number = 2; Status status = StartServable(servable_dir, "test_servable", version_number); EXPECT_TRUE(status.IsSuccess()); } TEST_F(TestStartWorker, test_worker_start_version_number_no_valid) { auto servable_dir = std::string(test_info_->test_case_name()) + "_test_servable_dir"; Init(servable_dir, "test_servable", 0, "test_add.mindir"); Init(servable_dir, "test_servable", -2, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable Status status = StartServable(servable_dir, "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg( status.StatusMessage(), "Start servable failed: There is no specified version directory of models, specified version number: 1"); } TEST_F(TestStartWorker, test_worker_start_error_servable_dir) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable std::string error_servable_dir = "test_servable_dir_error"; Status status = StartServable(error_servable_dir, 
"test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg( status.StatusMessage(), "Start servable failed: There is no specified version directory of models, specified version number: 0"); } TEST_F(TestStartWorker, test_worker_start_error_servable_name) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable std::string error_servable_name = "test_servable_error"; Status status = StartServable("test_servable_dir", error_servable_name, 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "'test_servable_error' has not been registered"); } TEST_F(TestStartWorker, test_worker_start_error_servable_format) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "om", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Not support device type Ascend and model type OM. "); } TEST_F(TestStartWorker, test_worker_start_no_registered_method) { Init("test_servable_dir", "test_servable", 2, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); // no registered method // RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 2); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "There is no method registered for servable"); } TEST_F(TestStartWorker, test_worker_start_no_declared_servable) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); // no declared method // DeclareServable("test_servable", "test_add.mindir", "mindir", true); auto status = RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "RegisterInputOutputInfo failed, cannot find model test_add.mindir"); } TEST_F(TestStartWorker, test_worker_start_multi_method) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, 1); RegisterMethod("test_servable", "test_add.mindir", "add_common2", {"x1", "x2"}, {"y"}, 2, 1); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); } TEST_F(TestStartWorker, test_worker_start_method_servable_input_count_not_match) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); size_t servable_input_count = 1; RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, servable_input_count, 1); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The inputs count 1 in register_method not equal to the count 2 defined in model") } TEST_F(TestStartWorker, test_worker_start_method_servable_output_count_not_match) { Init("test_servable_dir", "test_servable", 1, 
"test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); size_t servable_output_count = 2; RegisterMethod("test_servable", "test_add.mindir", "add_common", {"x1", "x2"}, {"y"}, 2, servable_output_count); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The outputs count 2 in register_method not equal to the count 1 defined in model") } // Test data flow in input\preprocess\predict\postprocess TEST_F(TestStartWorker, test_worker_start_preprocess_not_found) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature; method_signature.servable_name = "test_servable"; method_signature.method_name = "add_common"; method_signature.inputs = {"x1", "x2"}; method_signature.outputs = {"y"}; // preprocess try { method_signature.AddStageFunction("preprocess_fake_fun", {{0, 0}, {0, 0}}); // method input 0 and input 1 as servable input method_signature.AddStageModel("test_add.mindir", {{1, 0}, {0, 1}}, 0, ""); // servable output as method output method_signature.SetReturn({{2, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); } catch (std::runtime_error &ex) { ExpectContainMsg(ex.what(), "Function 'preprocess_fake_fun' is not defined") } } TEST_F(TestStartWorker, test_worker_start_with_preproces_and_postprocess_success) { Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); ServableRegister::Instance().RegisterInputOutputInfo("test_add.mindir", 2, 1); MethodSignature method_signature; method_signature.servable_name = "test_servable"; method_signature.method_name = "add_cast"; method_signature.inputs = {"x1", "x2"}; method_signature.outputs = {"y"}; // preprocess, stage 1, input is input data(stage index = 0) 0 and 1 method_signature.AddStageFunction("stub_preprocess_cast_int32_to_fp32_cpp", {{0, 0}, {0, 0}}); // model, stage 2, input is stage 1 output data 0 and 1 method_signature.AddStageModel("test_add.mindir", {{1, 0}, {1, 1}}, 0); // postprocess, stage 3, input is stage 2 output data 0 and 1 method_signature.AddStageFunction("stub_postprocess_cast_fp32_to_int32_cpp", {{2, 0}}); // method output, stage 3 output data 0 method_signature.SetReturn({{3, 0}}); ServableRegister::Instance().RegisterMethod(method_signature); // start_servable Status status = StartServable("test_servable_dir", "test_servable", 1); EXPECT_TRUE(status.IsSuccess()); } } // namespace serving } // namespace mindspore ================================================ FILE: tests/ut/python/CMakeLists.txt ================================================ set(STUB_DIR ../stub) set(ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../..) file(GLOB_RECURSE UT_SERVING_STUB RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "${STUB_DIR}/cxx_api/*.cc" "${STUB_DIR}/graph_impl_stub.cc" "${STUB_DIR}/include/utils/*.cc") add_library(mindspore SHARED ${UT_SERVING_STUB}) set(UT_SERVING_COMMON ${UT_SERVING_CORE_SRC} ${UT_SERVING_STUB}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${STUB_DIR}/..) 
================================================
FILE: tests/ut/python/CMakeLists.txt
================================================
set(STUB_DIR ../stub)
set(ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../..)

file(GLOB_RECURSE UT_SERVING_STUB RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        "${STUB_DIR}/cxx_api/*.cc"
        "${STUB_DIR}/graph_impl_stub.cc"
        "${STUB_DIR}/include/utils/*.cc")

add_library(mindspore SHARED ${UT_SERVING_STUB})
set(UT_SERVING_COMMON ${UT_SERVING_CORE_SRC} ${UT_SERVING_STUB})

include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${STUB_DIR}/..)
include_directories(${STUB_DIR})
include_directories(${STUB_DIR}/include)
include_directories(${ROOT_DIR}/third_party)

link_directories(${CMAKE_BINARY_DIR}/securec/src)
target_link_libraries(mindspore PRIVATE ${SECUREC_LIBRARY} pthread)
target_link_libraries(mindspore PRIVATE mindspore_serving::glog)

set(LIBRARY_OUTPUT_PATH ${ROOT_DIR}/build/package/tests/mindspore/lib/)

# copy mindspore include
file(COPY ${STUB_DIR}/include/api DESTINATION ${ROOT_DIR}/build/package/tests/mindspore/include)

================================================
FILE: tests/ut/python/mindspore/dataset/__init__.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

================================================
FILE: tests/ut/python/runtest.sh
================================================
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

set -e
BASEPATH=$(
  cd "$(dirname "$0")"
  pwd
)
PROJECT_PATH=${BASEPATH}/../../..

rm -rf ${PROJECT_PATH}/build/package/mindspore_serving/server
rm -rf ${PROJECT_PATH}/build/package/mindspore_serving/client
cp -r ${PROJECT_PATH}/mindspore_serving/server ${PROJECT_PATH}/build/package/mindspore_serving/
cp -r ${PROJECT_PATH}/mindspore_serving/client ${PROJECT_PATH}/build/package/mindspore_serving/

export PYTHONPATH=${PROJECT_PATH}/build/package:${PROJECT_PATH}/tests/ut/python:$PYTHONPATH
export LD_LIBRARY_PATH=${PROJECT_PATH}/build/package/tests/mindspore/lib:${LD_LIBRARY_PATH}
echo "PYTHONPATH=$PYTHONPATH"
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
export GLOG_v=1
unset http_proxy
unset https_proxy

function clear_port() {
  PROCESS=`netstat -nlp | grep :$1 | awk '{print $7}' | awk -F"/" '{print $1}'`
  for i in $PROCESS; do
    echo "Kill the process [ $i ]"
    kill -9 $i
  done
}

port_list=(5500 6200 7000 7001 7002 7003 7004 7005 7006 7007)
for port in ${port_list[*]}; do
  clear_port ${port}
done

cd ${PROJECT_PATH}/tests/ut/python/tests/
if [ $# -gt 0 ]; then
  pytest -s -v . -k "$1"
else
  pytest -v .
fi

rm -f *.crt *.csr *.key *.srl
rm -rf unix_socket_files
exit $?
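clear_port above greps netstat output to free the fixed test ports (5500 gRPC/RESTful, 6200 distributed, 7000+ agents) before a run. A rough Python equivalent of the same idea, using psutil (already a dependency of these tests); a sketch for illustration, not part of the suite:

# Sketch: free a TCP port by killing whatever still listens on it.
import signal
import psutil


def clear_port(port):
    # scan listening TCP sockets and kill the owning processes
    for conn in psutil.net_connections(kind="tcp"):
        if conn.status == psutil.CONN_LISTEN and conn.laddr.port == port and conn.pid:
            print(f"Kill the process [ {conn.pid} ]")
            psutil.Process(conn.pid).send_signal(signal.SIGKILL)


for port in (5500, 6200, 7000, 7001, 7002, 7003, 7004, 7005, 7006, 7007):
    clear_port(port)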
================================================
FILE: tests/ut/python/servable_config/add_servable_config.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""add model servable config"""
import numpy as np
from mindspore_serving.server import register


def add_trans_datatype(x1, x2):
    """define preprocess, this example has two inputs and two outputs"""
    return x1.astype(np.float32), x2.astype(np.float32)


# when with_batch_dim is set to False, only 2x2 add is supported
# when with_batch_dim is set to True(default), Nx2 add is supported, while N is viewed as batch
# float32 inputs/outputs
model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)


# register add_common method in add
@register.register_method(output_names=["y"])
def add_common(x1, x2):  # only support float32 inputs
    """method add_common data flow definition, only call model servable"""
    y = register.add_stage(model, x1, x2, outputs_count=1)
    return y


# register add_cast method in add
@register.register_method(output_names=["y"])
def add_cast(x1, x2):
    """method add_cast data flow definition, only call preprocess and model servable"""
    x1, x2 = register.add_stage(add_trans_datatype, x1, x2, outputs_count=2)  # cast input to float32
    y = register.add_stage(model, x1, x2, outputs_count=1)
    return y

================================================
FILE: tests/ut/python/servable_config/generate_certs.sh
================================================
#!/bin/bash

echo "[req]
default_bits = 2048
distinguished_name = req_distinguished_name
x509_extensions = v3_req
prompt = no
[req_distinguished_name]
countryName = XX
stateOrProvinceName = Self-signed Cert
commonName = Self-signed Cert
[v3_req]
basicConstraints = CA:TRUE" > ca.cnf

# generate ca's cert and private key for signing server and client cert
openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout ca.key -out ca.crt -config ca.cnf
rm ca.cnf

# generate server's cert
IP=$SERVING_IP
DNS=$SERVING_HOSTNAME
CN=$SERVING_COMMON_NAME
echo "
authorityKeyIdentifier=keyid,issuer
basicConstraints=CA:FALSE
keyUsage = digitalSignature, nonRepudiation, keyEncipherment, dataEncipherment
subjectAltName = @alt_names
[alt_names]
IP.1 = $IP
DNS.1 = $DNS
" > server.cnf
openssl genrsa -out server.key 2048
openssl req -new -key server.key -out server.csr -subj "/C=XX/ST=MyST/L=XX/O=HW/OU=gRPC/CN=$CN"
openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out server.crt -days 730 -sha256 -extfile server.cnf
rm server.cnf

# generate client's cert
openssl genrsa -out client.key 2048
openssl req -new -key client.key -out client.csr -subj "/C=XX/ST=MyST/L=XX/O=HW/OU=gRPC/CN=client"
openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out client.crt -days 730 -sha256
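generate_certs.sh produces a self-signed CA plus server and client certificates, with the server's SAN entries driven by the SERVING_IP/SERVING_HOSTNAME/SERVING_COMMON_NAME environment variables. A sketch of how these files plug into a TLS-enabled server and client; the SSLConfig parameter names (certificate, private_key, custom_ca, verify_client) and the ssl_config keyword are assumptions based on this repo's SSLConfig docs, so treat this as illustrative only:

# Sketch: wiring the generated certs into mutual-TLS gRPC serving.
from mindspore_serving import server
from mindspore_serving.client import Client, SSLConfig as ClientSSLConfig

# server side: present server.crt/server.key, verify clients against ca.crt
ssl_server = server.SSLConfig(certificate="server.crt", private_key="server.key",
                              custom_ca="ca.crt", verify_client=True)
server.start_grpc_server("127.0.0.1:5500", ssl_config=ssl_server)

# client side: present client.crt/client.key, verify the server against ca.crt
ssl_client = ClientSSLConfig(certificate="client.crt", private_key="client.key", custom_ca="ca.crt")
client = Client("localhost:5500", "add", "add_common", ssl_config=ssl_client)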
================================================
FILE: tests/ut/python/tests/common.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test Serving, Common"""

import os
from functools import wraps

from mindspore_serving import server
from mindspore_serving import log as logger
from mindspore_serving.client import Client

servable_index = 0


class ServingTestBase:
    def __init__(self):
        servable_dir = "serving_python_ut_servables"
        self.servable_dir = os.path.join(os.getcwd(), servable_dir)
        os.system(f"rm -rf {self.servable_dir}")
        global servable_index
        self.servable_name = "add_" + str(servable_index)
        servable_index += 1

    def init_servable(self, version_number, config_file, model_file="tensor_add.mindir"):
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        config_file_abs = os.path.join(os.path.join(cur_dir, "../servable_config/"), config_file)
        try:
            with open(config_file_abs, "r") as fp:
                servable_config_content = fp.read()
        except FileNotFoundError:
            servable_config_content = None
        self.init_servable_with_servable_config(version_number, servable_config_content, model_file)

    def init_servable_with_servable_config(self, version_number, servable_config_content,
                                           model_file="tensor_add.mindir", model_config_file=None):
        if not isinstance(model_file, (tuple, list)):
            model_file = (model_file,)
        self.version_number = version_number
        self.model_files = model_file
        self.servable_name_path = os.path.join(self.servable_dir, self.servable_name)
        self.version_number_path = os.path.join(self.servable_name_path, str(version_number))
        self.model_files_path = [os.path.join(self.version_number_path, file) for file in model_file]
        try:
            os.mkdir(self.servable_dir)
        except FileExistsError:
            pass
        try:
            os.mkdir(self.servable_name_path)
        except FileExistsError:
            pass
        if self.model_files_path and version_number is not None:
            try:
                os.mkdir(self.version_number_path)
            except FileExistsError:
                pass
            for file in self.model_files_path:
                with open(file, "w") as fp:
                    print("model content", file=fp)
        if servable_config_content is not None:
            config_file = os.path.join(self.servable_name_path, "servable_config.py")
            with open(config_file, "w") as fp:
                fp.write(servable_config_content)
        if model_config_file is not None:
            model_config_file_path = os.path.join(self.servable_name_path, model_config_file)
            with open(model_config_file_path, "w") as fp:
                print("model config file", file=fp)

    def init_distributed_servable(self, servable_config_content, rank_size, rank_table_content):
        self.version_number = 1
        self.servable_name_path = os.path.join(self.servable_dir, self.servable_name)
        self.model_dir = os.path.join(self.servable_dir, "model_" + self.servable_name)
        self.rank_table_content_path = os.path.join(self.servable_dir, self.servable_name + "_hccl.json")
        try:
            os.mkdir(self.servable_dir)
        except FileExistsError:
            pass
        try:
            os.mkdir(self.servable_name_path)
        except FileExistsError:
            pass
        try:
            os.mkdir(self.model_dir)
        except FileExistsError:
            pass
        self.model_file_list = []
        for i in range(rank_size):
            model_file_path = os.path.join(self.model_dir, f"model{i}.mindir")
            self.model_file_list.append(model_file_path)
            with open(model_file_path, "w") as fp:
                print("model content", file=fp)
        self.group_config_list = []
        for i in range(rank_size):
            group_config = os.path.join(self.model_dir, f"group{i}.pb")
            self.group_config_list.append(group_config)
            with open(group_config, "w") as fp:
                print("group config content", file=fp)
        if servable_config_content is not None:
            config_file = os.path.join(self.servable_name_path, "servable_config.py")
            with open(config_file, "w") as fp:
                fp.write(servable_config_content)
        if rank_table_content is not None:
            with open(self.rank_table_content_path, "w") as fp:
                fp.write(rank_table_content)

    @staticmethod
    def add_on_exit(fun):
        global exit_fun_list
        exit_fun_list.append(fun)


exit_fun_list = []
client_create_list = []


def serving_test(func):
    @wraps(func)
    def wrap_test(*args, **kwargs):
        try:
            os.environ["SERVING_ENABLE_CPU_DEVICE"] = "0"
            os.environ["SERVING_ENABLE_GPU_DEVICE"] = "0"
            func(*args, **kwargs)
        except Exception:
            logger.error("Serving test catch exception")
            serving_logs_dir = os.path.join(os.getcwd(), "serving_logs")
            os.system(f"ls -l {serving_logs_dir}/*.log && cat {serving_logs_dir}/*.log")
            raise
        finally:
            logger.info("Serving test begin to clear")
            server.master.context.set_max_enqueued_requests(10000)
            server.stop()
            global client_create_list
            for client in client_create_list:
                del client.stub
                client.stub = None
            client_create_list = []
            global exit_fun_list
            for fun in exit_fun_list:
                fun()
            exit_fun_list = []
            cwd_dir = os.getcwd()
            servable_dir = os.path.join(cwd_dir, "serving_python_ut_servables")
            os.system(f"rm -rf {servable_dir}")
            temp_rank_dir = os.path.join(cwd_dir, "temp_rank_table")
            os.system(f"rm -rf {temp_rank_dir}")
            serving_logs_dir = os.path.join(cwd_dir, "serving_logs")
            os.system(f"rm -rf {serving_logs_dir}")
            unix_socket_files_dir = os.path.join(cwd_dir, "unix_socket_files")
            os.system(f"rm -rf {unix_socket_files_dir}")
            unix_socket_files_dir = os.path.join(cwd_dir, "device_")
            os.system(f"rm -rf {unix_socket_files_dir}*")
            os.system("rm -rf *.crt *.key *.csr *.srl")
            logger.info("Serving test end clear")

    return wrap_test


def create_client(address, servable_name, method_name, version_number=0, ssl_config=None):
    client = Client(address, servable_name, method_name, version_number, ssl_config)
    client_create_list.append(client)
    return client


def generate_cert(server_ip="0.0.0.0", server_host_name="serving", common_name="serving.com"):
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    shell_path = os.path.join(os.path.join(cur_dir, "../servable_config/"), "generate_certs.sh")
    os.environ["SERVING_IP"] = server_ip
    os.environ["SERVING_HOSTNAME"] = server_host_name
    os.environ["SERVING_COMMON_NAME"] = common_name
    with open(shell_path, 'r') as f:
        command = f.read()
    os.system(command)


def release_client(client):
    del client.stub
    client.stub = None


# test servable_config.py with client
servable_config_import = r"""
import numpy as np
from mindspore_serving.server import register
"""

servable_config_declare_servable = r"""
register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
"""

servable_config_preprocess_cast = r"""
def add_trans_datatype(x1, x2):
    return x1.astype(np.float32), x2.astype(np.float32)
"""

servable_config_method_add_common = r"""
@register.register_method(output_names=["y"])
def add_common(x1, x2):  # only support float32 inputs
    y = register.call_servable(x1, x2)
    return y
"""

servable_config_method_add_cast = r"""
@register.register_method(output_names=["y"])
def add_cast(x1, x2):
    x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2)  # cast input to float32
    y = register.call_servable(x1, x2)
    return y
"""


def init_add_servable():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += servable_config_preprocess_cast
    servable_content += servable_config_method_add_common
    servable_content += servable_config_method_add_cast
    base.init_servable_with_servable_config(1, servable_content)
    return base


def init_str_servable():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += r"""
def preprocess(other):
    return np.ones([2,2], np.float32), np.ones([2,2], np.float32)


def str_concat_postprocess(text1, text2):
    print("text1", text1, "text2", text2)
    return text1 + text2


@register.register_method(output_names=["text"])
def str_concat(text1, text2):
    text = register.add_stage(str_concat_postprocess, text1, text2, outputs_count=1)
    return text


def str_empty_postprocess(text1, text2):
    if len(text1) == 0:
        text = text2
    else:
        text = ""
    return text


@register.register_method(output_names=["text"])
def str_empty(text1, text2):
    text = register.add_stage(str_empty_postprocess, text1, text2, outputs_count=1)
    return text
"""
    base.init_servable_with_servable_config(1, servable_content)
    return base


def init_bytes_servable():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += r"""
def preprocess(other):
    return np.ones([2,2], np.float32), np.ones([2,2], np.float32)


def bytes_concat_process(text1, text2):
    text1 = bytes.decode(text1.tobytes())  # bytes decode to str
    text2 = bytes.decode(text2.tobytes())  # bytes decode to str
    return str.encode(text1 + text2)  # str encode to bytes


@register.register_method(output_names=["text"])
def bytes_concat(text1, text2):
    text = register.add_stage(bytes_concat_process, text1, text2, outputs_count=1)
    return text


def bytes_empty_process(text1, text2):
    text1 = bytes.decode(text1.tobytes())  # bytes decode to str
    text2 = bytes.decode(text2.tobytes())  # bytes decode to str
    if len(text1) == 0:
        text = text2
    else:
        text = ""
    return str.encode(text)  # str encode to bytes


@register.register_method(output_names=["text"])
def bytes_empty(text1, text2):
    text = register.add_stage(bytes_empty_process, text1, text2, outputs_count=1)
    return text
"""
    base.init_servable_with_servable_config(1, servable_content)
    return base


def init_bool_int_float_servable():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += r"""
def bool_process(bool_val):
    return ~bool_val


@register.register_method(output_names=["value"])
def bool_not(bool_val):
    value = register.add_stage(bool_process, bool_val, outputs_count=1)
    return value


def int_process(int_val):
    return int_val + 1


@register.register_method(output_names=["value"])
def int_plus_1(int_val):
    value = register.add_stage(int_process, int_val, outputs_count=1)
    return value


def float_process(float_val):
    value = (float_val + 1).astype(float_val.dtype)  # also support float16 input and output
    return value


@register.register_method(output_names=["value"])
def float_plus_1(float_val):
    value = register.add_stage(float_process, float_val, outputs_count=1)
    return value
"""
    base.init_servable_with_servable_config(1, servable_content)
    return base


def start_serving_server(servable_content, model_file="tensor_add.mindir", version_number=1,
                         start_version_number=None, device_ids=0, num_parallel_workers=0, device_type=None):
    base = ServingTestBase()
    base.init_servable_with_servable_config(version_number, servable_content, model_file=model_file)
    if start_version_number is None:
        start_version_number = version_number
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=device_ids,
                                                      version_number=start_version_number,
                                                      num_parallel_workers=num_parallel_workers,
                                                      device_type=device_type))
    server.start_grpc_server("0.0.0.0:5500")
    return base
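Taken together, the helpers above are the scaffolding every Python UT in this suite builds on: servable_config snippets are concatenated into a config file, start_serving_server materializes and launches the servable, create_client tracks clients for cleanup, and the serving_test decorator tears everything down. A minimal hypothetical test combining them (mirroring the pattern of the real tests later in this dump):

# Hypothetical minimal test built on the helpers above.
import numpy as np
from common import serving_test, create_client, start_serving_server
from common import servable_config_import, servable_config_declare_servable, servable_config_method_add_common


@serving_test
def test_add_common_minimal_example():
    servable_content = servable_config_import + servable_config_declare_servable + servable_config_method_add_common
    base = start_serving_server(servable_content)
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    client = create_client("localhost:5500", base.servable_name, "add_common")
    result = client.infer([{"x1": x1, "x2": x2}])
    assert (result[0]["y"] == x1 + x2).all()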
================================================
FILE: tests/ut/python/tests/common_restful.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test Serving, Common"""

from multiprocessing import Process, Pipe
import json
import requests
import numpy as np

from common import init_str_servable, init_bytes_servable, init_bool_int_float_servable
from mindspore_serving import server


def compare_float_value(result, expect):
    if isinstance(expect, (float, int)):
        assert isinstance(result, float)
        assert abs(expect - result) < 0.001
        return
    expect = np.array(expect)
    result = np.array(result)
    assert (np.abs(expect - result) < 0.001).all()


def create_multi_instances_fp32(instance_count):
    instances = []
    # instance 1
    y_data_list = []
    for i in range(instance_count):
        x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1)
        x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1)
        y_data_list.append(x1 + x2)
        instances.append({"x1": x1.tolist(), "x2": x2.tolist()})
    return instances, y_data_list


def create_multi_instances_with_batch_fp32(instance_count):
    instances = []
    # instance 1
    y_data_list = []
    for i in range(instance_count):
        x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
        x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1)
        y_data_list.append(x1 + x2)
        instances.append({"x1": x1.tolist(), "x2": x2.tolist()})
    return instances, y_data_list


def check_number_result(result, y_data_list, output_name="y"):
    result = result["instances"]
    assert len(result) == len(y_data_list)
    for result_item, expected_item in zip(result, y_data_list):
        result_item = np.array(result_item[output_name])
        print("result", result_item)
        print("expect:", expected_item)
        assert result_item.shape == expected_item.shape
        assert (np.abs(result_item - expected_item) < 0.001).all()


def post_restful(address, servable_name, method_name, json_instances, version_number=None, verify="ca.crt",
                 cert=("client.crt", "client.key"), https=False, post_payload=None):
    if not post_payload:
        instances_map = {"instances": json_instances}
        post_payload = json.dumps(instances_map)
    print("request:", post_payload[:200])
    protocol = "http"
    if https:
        protocol = "https"

    def post_request(request_url, post_payload, send_pipe, verify=verify, cert=cert):
        try:
            if https:
                result = requests.post(request_url, data=post_payload, verify=verify, cert=cert)
            else:
                result = requests.post(request_url, data=post_payload)
            print(f"result inner: {result}")
            result = json.loads(result.text)
            send_pipe.send(result)
        # pylint: disable=broad-except
        except Exception as e:
            print(f"post failed: {e}")
            send_pipe.send("post failed")

    if version_number is not None:
        request_url = f"{protocol}://{address}/model/{servable_name}/version/{version_number}:{method_name}"
    else:
        request_url = f"{protocol}://{address}/model/{servable_name}:{method_name}"

    result = None
    for _ in range(2):
        send_pipe, recv_pipe = Pipe()
        sub_process = Process(target=post_request, args=(request_url, post_payload, send_pipe))
        sub_process.start()
        sub_process.join()
        if recv_pipe.poll(0.1):
            result = recv_pipe.recv()
            if result != "post failed":
                break
        else:
            result = "post failed"
    print(f"result outer: {result}")
    return result


def start_str_restful_server():
    base = init_str_servable()
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_restful_server("0.0.0.0:5500")
    return base


def start_bytes_restful_server():
    base = init_bytes_servable()
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_restful_server("0.0.0.0:5500")
    return base


def start_bool_int_float_restful_server():
    base = init_bool_int_float_servable()
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_restful_server("0.0.0.0:5500")
    return base
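post_restful above builds the RESTful URL as /model/{servable_name}[:/version/{version_number}]:{method_name}, POSTs {"instances": [...]}, and retries once in a subprocess so a hung request cannot block the test. A hypothetical round trip using these helpers (the servable name "add_0" is illustrative; real tests take it from ServingTestBase):

# Hypothetical RESTful round trip; assumes a server exposing add_common
# is already listening on port 5500.
from common_restful import create_multi_instances_fp32, post_restful, check_number_result

instances, y_data_list = create_multi_instances_fp32(3)
result = post_restful("localhost:5500", "add_0", "add_common", instances)
# on success, result has the shape {"instances": [{"y": ...}, ...]}
check_number_result(result, y_data_list)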
================================================
FILE: tests/ut/python/tests/test_distributed_worker.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test distributed worker"""

import logging
import os
import signal
import time
from multiprocessing import Process, Pipe

import numpy as np
import psutil

from common import serving_test, create_client, ServingTestBase
from mindspore_serving.server import distributed
from mindspore_serving import server

distributed_import = r"""
import numpy as np
from mindspore_serving.server import distributed
from mindspore_serving.server import register
"""

distributed_declare_servable = r"""
model = distributed.declare_servable(rank_size=8, stage_size=1, with_batch_dim=False)
"""

rank_table_content = r"""
{
  "version": "1.0",
  "server_count": "1",
  "server_list": [
    {
      "server_id": "127.0.0.1",
      "device": [
        {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
        {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"},
        {"device_id": "2", "device_ip": "192.3.27.6", "rank_id": "2"},
        {"device_id": "3", "device_ip": "192.4.27.6", "rank_id": "3"},
        {"device_id": "4", "device_ip": "192.1.27.7", "rank_id": "4"},
        {"device_id": "5", "device_ip": "192.2.27.7", "rank_id": "5"},
        {"device_id": "6", "device_ip": "192.3.27.7", "rank_id": "6"},
        {"device_id": "7", "device_ip": "192.4.27.7", "rank_id": "7"}
      ],
      "host_nic_ip": "reserve"
    }
  ],
  "status": "completed"
}
"""


def init_distributed_servable():
    base = ServingTestBase()
    servable_content = distributed_import
    servable_content += distributed_declare_servable
    servable_content += r"""
@register.register_method(output_names=["y"])
def predict(x1, x2):
    y = register.add_stage(model, x1, x2, outputs_count=1)
    return y
"""
    base.init_distributed_servable(servable_content, 8, rank_table_content)
    return base


def start_distributed_grpc_server():
    base = init_distributed_servable()
    return base


def start_distributed_worker(base):
    send_pipe, recv_pipe = Pipe()

    def worker_process(send_pipe):
        try:
            distributed.start_servable(base.servable_dir, base.servable_name,
                                       rank_table_json_file=base.rank_table_content_path,
                                       distributed_address="127.0.0.1:6200")
            server.start_grpc_server("0.0.0.0:5500")
            send_pipe.send("Success")
        # pylint: disable=broad-except
        except Exception as e:
            logging.exception(e)
            send_pipe.send(e)

    worker = Process(target=worker_process, args=(send_pipe,))
    worker.start()
    time.sleep(0.5)  # wait parse rank table ready
    assert worker.is_alive()
    return worker, recv_pipe


def wait_worker_registered_ready(worker, recv_pipe):
    index = 0
    while index < 100 and worker.is_alive():  # wait max 10 s
        index += 1
        if recv_pipe.poll(0.1):
            msg = recv_pipe.recv()
            print(f"Receive worker process msg: {msg} {worker.is_alive()}")
            if isinstance(msg, Exception):
                raise msg
            break
    if recv_pipe.poll(0.1):
        msg = recv_pipe.recv()
        print(f"Receive worker process msg: {msg} {worker.is_alive()}")
        if isinstance(msg, Exception):
            raise msg
    assert index < 100
    assert worker.is_alive()


def start_agents(model_file_list, group_config_list, start_port, dec_key=None, dec_mode='AES-GCM'):
    send_pipe, recv_pipe = Pipe()

    def agent_process(send_pipe):
        try:
            distributed.startup_agents(distributed_address="127.0.0.1:6200", model_files=model_file_list,
                                       group_config_files=group_config_list, agent_start_port=start_port,
                                       dec_key=dec_key, dec_mode=dec_mode)
            send_pipe.send("Success")
        # pylint: disable=broad-except
        except Exception as e:
            logging.exception(e)
            send_pipe.send(e)

    agent = Process(target=agent_process, args=(send_pipe,))
    agent.start()
    index = 0
    while index < 100 and agent.is_alive():  # wait max 10 s
        index += 1
        if recv_pipe.poll(0.1):
            msg = recv_pipe.recv()
            print(f"Receive agent process msg: {msg} {agent.is_alive()}")
            if isinstance(msg, Exception):
                raise msg
            break
    if recv_pipe.poll(0.1):
        msg = recv_pipe.recv()
        print(f"Receive agent process msg: {msg} {agent.is_alive()}")
        if isinstance(msg, Exception):
            raise msg
    assert index < 100
    assert agent.is_alive()
    return agent


def send_exit(process):
    if not process.is_alive():
        return
    parent_process = psutil.Process(process.pid)
    child_processes = parent_process.children(recursive=True)

    def children_alive():
        return any([item.is_running() for item in child_processes])

    os.kill(process.pid, signal.SIGINT)
    for _ in range(50):  # 50*0.1s
        if not process.is_alive() and not children_alive():
            break
        time.sleep(0.1)
    for item in child_processes:
        if item.is_running():
            os.kill(item.pid, signal.SIGKILL)
    if process.is_alive():
        os.kill(process.pid, signal.SIGKILL)


def start_distributed_serving_server():
    base = start_distributed_grpc_server()
    worker_process, recv_pipe = start_distributed_worker(base)
    base.add_on_exit(lambda: send_exit(worker_process))
    agent_process = start_agents(base.model_file_list, base.group_config_list, 7000)
    base.add_on_exit(lambda: send_exit(agent_process))
    wait_worker_registered_ready(worker_process, recv_pipe)
    return base, worker_process, agent_process


@serving_test
def test_distributed_worker_worker_exit_success():
    """
    Feature: distributed serving server
    Description: Test distributed serving server exit when the worker receives signal SIGINT
    Expectation: when the worker receives signal SIGINT, the serving server will exit.
    """
    base, worker_process, agent_process = start_distributed_serving_server()
    client = create_client("localhost:5500", base.servable_name, "predict")
    instances = [{}, {}, {}]
    y_data_list = []
    for index, instance in enumerate(instances):
        instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1)
        instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1)
        y_data_list.append((instance["x1"] + instance["x2"]).tolist())
    result = client.infer(instances)
    print(result)
    assert len(result) == 3
    assert result[0]["y"].dtype == np.float32
    assert result[1]["y"].dtype == np.float32
    assert result[2]["y"].dtype == np.float32
    assert result[0]["y"].tolist() == y_data_list[0]
    assert result[1]["y"].tolist() == y_data_list[1]
    assert result[2]["y"].tolist() == y_data_list[2]

    # send SIGINT to worker, expect worker and all agents exit
    agents = psutil.Process(agent_process.pid).children()

    def agents_alive():
        return any([item.is_running() for item in agents])

    os.kill(worker_process.pid, signal.SIGINT)
    for _ in range(50):  # 50*0.1s
        if not worker_process.is_alive() and not agent_process.is_alive() and not agents_alive():
            break
        time.sleep(0.1)
    assert not worker_process.is_alive()
    assert not agent_process.is_alive()
    assert not agents_alive()


@serving_test
def test_distributed_worker_agent_exit_success():
    """
    Feature: distributed serving server
    Description: Test distributed serving server exit when the agent startup process receives signal SIGINT
    Expectation: when the agent startup process receives signal SIGINT, the serving server will exit.
    """
    base, worker_process, agent_process = start_distributed_serving_server()
    client = create_client("localhost:5500", base.servable_name, "predict")
    instances = [{}, {}, {}]
    y_data_list = []
    for index, instance in enumerate(instances):
        instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1)
        instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1)
        y_data_list.append((instance["x1"] + instance["x2"]).tolist())
    result = client.infer(instances)
    print(result)
    assert len(result) == 3
    assert result[0]["y"].tolist() == y_data_list[0]
    assert result[1]["y"].tolist() == y_data_list[1]
    assert result[2]["y"].tolist() == y_data_list[2]

    # send SIGINT to the agent startup process, expect worker and all agents exit
    agents = psutil.Process(agent_process.pid).children()

    def agents_alive():
        return any([item.is_running() for item in agents])

    os.kill(agent_process.pid, signal.SIGINT)
    for _ in range(50):  # 50*0.1s
        if not worker_process.is_alive() and not agent_process.is_alive() and not agents_alive():
            break
        time.sleep(0.1)
    assert not worker_process.is_alive()
    assert not agent_process.is_alive()
    assert not agents_alive()


@serving_test
def test_distributed_worker_agent_startup_killed_exit_success():
    """
    Feature: distributed serving server
    Description: Test distributed serving server exit when the agent startup process is killed by signal SIGKILL
    Expectation: when the agent startup process receives signal SIGKILL, the serving server will exit.
    """
    base, worker_process, agent_process = start_distributed_serving_server()
    client = create_client("localhost:5500", base.servable_name, "predict")
    instances = [{}, {}, {}]
    y_data_list = []
    for index, instance in enumerate(instances):
        instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1)
        instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1)
        y_data_list.append((instance["x1"] + instance["x2"]).tolist())
    result = client.infer(instances)
    print(result)
    assert len(result) == 3
    assert result[0]["y"].tolist() == y_data_list[0]
    assert result[1]["y"].tolist() == y_data_list[1]
    assert result[2]["y"].tolist() == y_data_list[2]

    # send SIGKILL to the agent startup process, expect worker and all agents exit
    agents = psutil.Process(agent_process.pid).children()

    def agents_alive():
        return any([item.is_running() for item in agents])

    os.kill(agent_process.pid, signal.SIGKILL)  # kill msg
    for _ in range(50):  # 50*0.1s
        # test agent_process.is_alive() first, it will make agents(children) notify exit of their parent
        if not agent_process.is_alive() and not worker_process.is_alive() and not agents_alive():
            break
        time.sleep(0.1)
    assert not worker_process.is_alive()
    assert not agent_process.is_alive()
    assert not agents_alive()


@serving_test
def test_distributed_worker_agent_killed_exit_success():
    """
    Feature: distributed serving server
    Description: Test distributed serving server exit when one of the agents is killed by signal SIGKILL
    Expectation: when one of the agent processes receives signal SIGKILL, the serving server will exit.
    """
    base, worker_process, agent_process = start_distributed_serving_server()
    client = create_client("localhost:5500", base.servable_name, "predict")
    instances = [{}, {}, {}]
    y_data_list = []
    for index, instance in enumerate(instances):
        instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1)
        instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1)
        y_data_list.append((instance["x1"] + instance["x2"]).tolist())
    result = client.infer(instances)
    print(result)
    assert len(result) == 3
    assert result[0]["y"].tolist() == y_data_list[0]
    assert result[1]["y"].tolist() == y_data_list[1]
    assert result[2]["y"].tolist() == y_data_list[2]

    # send SIGKILL to one of the agents, expect worker and all agents exit
    agents = psutil.Process(agent_process.pid).children()
    assert agents

    def agents_alive():
        return any([item.is_running() for item in agents])

    os.kill(agents[0].pid, signal.SIGKILL)  # kill msg
    for _ in range(50):  # 50*0.1s
        if not worker_process.is_alive() and not agent_process.is_alive() and not agents_alive():
            break
        time.sleep(0.1)
    assert not worker_process.is_alive()
    assert not agent_process.is_alive()
    assert not agents_alive()


@serving_test
def test_distributed_worker_agent_invalid_model_files_failed():
    """
    Feature: distributed serving server
    Description: Test distributed serving server startup when model files are invalid
    Expectation: the serving server raises a runtime error.
    """
    base = start_distributed_grpc_server()
    worker_process, _ = start_distributed_worker(base)
    base.add_on_exit(lambda: send_exit(worker_process))
    base.model_file_list[0] = base.model_file_list[0] + "_error"
    try:
        start_agents(base.model_file_list, base.group_config_list, 7036)
        assert False
    # pylint: disable=broad-except
    except Exception as e:
        assert "Cannot access model file" in str(e)


@serving_test
def test_distributed_worker_dec_model_success():
    """
    Feature: distributed serving server
    Description: Test distributed serving server with dec models
    Expectation: the serving server runs ok.
    """
    base = start_distributed_grpc_server()
    worker_process, recv_pipe = start_distributed_worker(base)
    base.add_on_exit(lambda: send_exit(worker_process))
    agent_process = start_agents(base.model_file_list, base.group_config_list, 7000,
                                 dec_key=('abcd1234' * 3).encode())
    base.add_on_exit(lambda: send_exit(agent_process))
    wait_worker_registered_ready(worker_process, recv_pipe)
    client = create_client("localhost:5500", base.servable_name, "predict")
    instances = [{}, {}, {}]
    y_data_list = []
    for index, instance in enumerate(instances):
        instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1)
        instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1)
        y_data_list.append((instance["x1"] + instance["x2"]).tolist())
    result = client.infer(instances)
    print(result)
    assert len(result) == 3
    assert result[0]["y"].dtype == np.float32
    assert result[1]["y"].dtype == np.float32
    assert result[2]["y"].dtype == np.float32
    assert result[0]["y"].tolist() == y_data_list[0]
    assert result[1]["y"].tolist() == y_data_list[1]
    assert result[2]["y"].tolist() == y_data_list[2]

    # send SIGINT to worker, expect worker and all agents exit
    agents = psutil.Process(agent_process.pid).children()

    def agents_alive():
        return any([item.is_running() for item in agents])

    os.kill(worker_process.pid, signal.SIGINT)
    for _ in range(50):  # 50*0.1s
        if not worker_process.is_alive() and not agent_process.is_alive() and not agents_alive():
            break
        time.sleep(0.1)
    assert not worker_process.is_alive()
    assert not agent_process.is_alive()
    assert not agents_alive()
================================================
FILE: tests/ut/python/tests/test_grpc_request.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test Serving with master, worker and client"""

import numpy as np

from common import init_str_servable, init_bytes_servable, init_bool_int_float_servable
from common import serving_test, create_client
from mindspore_serving import server


def start_str_grpc_server():
    base = init_str_servable()
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    return base


def start_bytes_grpc_server():
    base = init_bytes_servable()
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    return base


def start_bool_int_float_grpc_server():
    base = init_bool_int_float_servable()
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    return base


@serving_test
def test_grpc_request_str_input_output_success():
    base = start_str_grpc_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = str_a[i]
        instance["text2"] = str_b[i]

    client = create_client("localhost:5500", base.servable_name, "str_concat")
    result = client.infer(instances)
    print("result", result)
    assert result[0]["text"] == str_a[0] + str_b[0]
    assert result[1]["text"] == str_a[1] + str_b[1]
    assert result[2]["text"] == str_a[2] + str_b[2]


@serving_test
def test_grpc_request_empty_str_input_output_success():
    base = start_str_grpc_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = str_a[i]
        instance["text2"] = str_b[i]

    client = create_client("localhost:5500", base.servable_name, "str_empty")
    result = client.infer(instances)
    assert result[0]["text"] == ""
    assert result[1]["text"] == "456"
    assert result[2]["text"] == ""


@serving_test
def test_grpc_request_str_shape1_list_input_failed():
    base = start_str_grpc_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = [str_a[i]]
        instance["text2"] = [str_b[i]]

    client = create_client("localhost:5500", base.servable_name, "str_concat")
    try:
        client.infer(instances)
        assert False
    except RuntimeError as e:
        assert "Not support value type " in str(e)


@serving_test
def test_grpc_request_str_np_1d_array_input_failed():
    base = start_str_grpc_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = np.array([str_a[i], str_a[i]])
        instance["text2"] = np.array([str_b[i], str_b[i]])
        print(instance)

    client = create_client("localhost:5500", base.servable_name, "str_concat")
    try:
        client.infer(instances)
        assert False
    except RuntimeError as e:
        assert "Unknown data type" in str(e)


@serving_test
def test_grpc_request_bytes_input_output_success():
    base = start_bytes_grpc_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = str.encode(str_a[i])
        instance["text2"] = str.encode(str_b[i])

    client = create_client("localhost:5500", base.servable_name, "bytes_concat")
    result = client.infer(instances)
    assert bytes.decode(result[0]["text"]) == str_a[0] + str_b[0]
    assert bytes.decode(result[1]["text"]) == str_a[1] + str_b[1]
    assert bytes.decode(result[2]["text"]) == str_a[2] + str_b[2]


@serving_test
def test_grpc_request_empty_bytes_input_output_success():
    base = start_bytes_grpc_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = str.encode(str_a[i])
        instance["text2"] = str.encode(str_b[i])

    client = create_client("localhost:5500", base.servable_name, "bytes_empty")
    result = client.infer(instances)
    assert bytes.decode(result[0]["text"]) == ""
    assert bytes.decode(result[1]["text"]) == str_b[1]
    assert bytes.decode(result[2]["text"]) == ""


@serving_test
def test_grpc_request_bytes_1d_array_input_failed():
    base = start_bytes_grpc_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = [str.encode(str_a[i])]
        instance["text2"] = [str.encode(str_b[i])]

    client = create_client("localhost:5500", base.servable_name, "bytes_concat")
    try:
        client.infer(instances)
        assert False
    except RuntimeError as e:
        assert "Not support value type " in str(e)


@serving_test
def test_grpc_request_bool_scalar_input_output_success():
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        instance["bool_val"] = (i % 2 == 0)

    client = create_client("localhost:5500", base.servable_name, "bool_not")
    result = client.infer(instances)
    assert not result[0]["value"]
    assert result[1]["value"]
    assert not result[2]["value"]


@serving_test
def test_grpc_request_bool_1d_array_input_output_success():
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i % 2 == 0)
        val = [val] * i
        instance["bool_val"] = np.array(val).astype(np.bool_)

    client = create_client("localhost:5500", base.servable_name, "bool_not")
    result = client.infer(instances)
    assert result[0]["value"].tolist() == []
    assert result[1]["value"].tolist() == [True]
    assert result[2]["value"].tolist() == [False, False]


@serving_test
def test_grpc_request_bool_2d_array_input_output_success():
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i % 2 == 0)
        val = [[val] * i] * i
        if i == 0:
            val = [[]]
        instance["bool_val"] = np.array(val).astype(np.bool_)

    client = create_client("localhost:5500", base.servable_name, "bool_not")
    result = client.infer(instances)
    assert result[0]["value"].tolist() == [[]]
    assert result[1]["value"].tolist() == [[True]]
    assert result[2]["value"].tolist() == [[False, False], [False, False]]


@serving_test
def test_grpc_request_bool_invalid_2d_array_input_failed():
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i % 2 == 0)
        val = [[val, val], [val]]
        instance["bool_val"] = np.array(val)

    client = create_client("localhost:5500", base.servable_name, "bool_not")
    try:
        client.infer(instances)
        assert False
    except RuntimeError as e:
        assert "Unknown data type object" in str(e)


@serving_test
def test_grpc_request_int_scalar_input_output_success():
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i * 2) * (-1 if i % 2 == 0 else 1)  # 0, 2, -4
        instance["int_val"] = val

    client = create_client("localhost:5500", base.servable_name, "int_plus_1")
    result = client.infer(instances)
    assert result[0]["value"] == 1
    assert result[1]["value"] == 3
    assert result[2]["value"] == -3


def common_test_grpc_request_np_int_type_scalar_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i * 2) * (-1 if i % 2 == 0 else 1)  # 0, 2, -4
        instance["int_val"] = dtype(val)

    client = create_client("localhost:5500", base.servable_name, "int_plus_1")
    result = client.infer(instances)
    assert result[0]["value"] == 1
    assert result[1]["value"] == 3
    assert result[2]["value"] == -3


@serving_test
def test_grpc_request_np_int8_type_scalar_input_output_success():
    common_test_grpc_request_np_int_type_scalar_input_output_success(np.int8)


@serving_test
def test_grpc_request_np_int16_type_scalar_input_output_success():
    common_test_grpc_request_np_int_type_scalar_input_output_success(np.int16)


@serving_test
def test_grpc_request_np_int32_type_scalar_input_output_success():
    common_test_grpc_request_np_int_type_scalar_input_output_success(np.int32)


@serving_test
def test_grpc_request_np_int64_type_scalar_input_output_success():
    common_test_grpc_request_np_int_type_scalar_input_output_success(np.int64)


def common_test_grpc_request_np_uint_type_scalar_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i * 2)  # 0, 2, 4
        instance["int_val"] = dtype(val)

    client = create_client("localhost:5500", base.servable_name, "int_plus_1")
    result = client.infer(instances)
    assert result[0]["value"] == 1
    assert result[1]["value"] == 3
    assert result[2]["value"] == 5


@serving_test
def test_grpc_request_np_uint8_type_scalar_input_output_success():
    common_test_grpc_request_np_uint_type_scalar_input_output_success(np.uint8)


@serving_test
def test_grpc_request_np_uint16_type_scalar_input_output_success():
    common_test_grpc_request_np_uint_type_scalar_input_output_success(np.uint16)


@serving_test
def test_grpc_request_np_uint32_type_scalar_input_output_success():
    common_test_grpc_request_np_uint_type_scalar_input_output_success(np.uint32)


@serving_test
def test_grpc_request_np_uint64_type_scalar_input_output_success():
    common_test_grpc_request_np_uint_type_scalar_input_output_success(np.uint64)


def common_test_grpc_request_np_int_type_1d_array_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i * 2) * (-1 if i % 2 == 0 else 1)  # 0, 2, -4
        val = [val] * i
        instance["int_val"] = np.array(val).astype(dtype)

    client = create_client("localhost:5500", base.servable_name, "int_plus_1")
    result = client.infer(instances)
    assert result[0]["value"].tolist() == []
    assert result[1]["value"].tolist() == [3]
    assert result[2]["value"].tolist() == [-3, -3]


@serving_test
def test_grpc_request_np_int8_type_1d_array_input_output_success():
    common_test_grpc_request_np_int_type_1d_array_input_output_success(np.int8)


@serving_test
def test_grpc_request_np_int16_type_1d_array_input_output_success():
    common_test_grpc_request_np_int_type_1d_array_input_output_success(np.int16)


@serving_test
def test_grpc_request_np_int32_type_1d_array_input_output_success():
    common_test_grpc_request_np_int_type_1d_array_input_output_success(np.int32)


@serving_test
def test_grpc_request_np_int64_type_1d_array_input_output_success():
    common_test_grpc_request_np_int_type_1d_array_input_output_success(np.int64)


def common_test_grpc_request_np_uint_type_1d_array_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i * 2)  # 0, 2, 4
        val = [val] * i
        instance["int_val"] = np.array(val).astype(dtype)

    client = create_client("localhost:5500", base.servable_name, "int_plus_1")
    result = client.infer(instances)
    assert result[0]["value"].tolist() == []
    assert result[1]["value"].tolist() == [3]
    assert result[2]["value"].tolist() == [5, 5]


@serving_test
def test_grpc_request_np_uint8_type_1d_array_input_output_success():
    common_test_grpc_request_np_uint_type_1d_array_input_output_success(np.uint8)


@serving_test
def test_grpc_request_np_uint16_type_1d_array_input_output_success():
    common_test_grpc_request_np_uint_type_1d_array_input_output_success(np.uint16)


@serving_test
def test_grpc_request_np_uint32_type_1d_array_input_output_success():
    common_test_grpc_request_np_uint_type_1d_array_input_output_success(np.uint32)


@serving_test
def test_grpc_request_np_uint64_type_1d_array_input_output_success():
    common_test_grpc_request_np_uint_type_1d_array_input_output_success(np.uint64)


def common_test_grpc_request_np_int_type_2d_array_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i * 2) * (-1 if i % 2 == 0 else 1)  # 0, 2, -4
        val = [[val] * i] * i
        if i == 0:
            val = [[]]
        instance["int_val"] = np.array(val).astype(dtype)

    client = create_client("localhost:5500", base.servable_name, "int_plus_1")
    result = client.infer(instances)
    assert result[0]["value"].tolist() == [[]]
    assert result[1]["value"].tolist() == [[3]]
    assert result[2]["value"].tolist() == [[-3, -3], [-3, -3]]


@serving_test
def test_grpc_request_np_int8_type_2d_array_input_output_success():
    common_test_grpc_request_np_int_type_2d_array_input_output_success(np.int8)


@serving_test
def test_grpc_request_np_int16_type_2d_array_input_output_success():
    common_test_grpc_request_np_int_type_2d_array_input_output_success(np.int16)


@serving_test
def test_grpc_request_np_int32_type_2d_array_input_output_success():
    common_test_grpc_request_np_int_type_2d_array_input_output_success(np.int32)


@serving_test
def test_grpc_request_np_int64_type_2d_array_input_output_success():
    common_test_grpc_request_np_int_type_2d_array_input_output_success(np.int64)


def common_test_grpc_request_np_uint_type_2d_array_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i * 2)  # 0, 2, 4
        val = [[val] * i] * i
        if i == 0:
            val = [[]]
        instance["int_val"] = np.array(val).astype(dtype)

    client = create_client("localhost:5500", base.servable_name, "int_plus_1")
    result = client.infer(instances)
    assert result[0]["value"].tolist() == [[]]
    assert result[1]["value"].tolist() == [[3]]
    assert result[2]["value"].tolist() == [[5, 5], [5, 5]]


@serving_test
def test_grpc_request_np_uint8_type_2d_array_input_output_success():
    common_test_grpc_request_np_uint_type_2d_array_input_output_success(np.uint8)


@serving_test
def test_grpc_request_np_uint16_type_2d_array_input_output_success():
    common_test_grpc_request_np_uint_type_2d_array_input_output_success(np.uint16)


@serving_test
def test_grpc_request_np_uint32_type_2d_array_input_output_success():
    common_test_grpc_request_np_uint_type_2d_array_input_output_success(np.uint32)


@serving_test
def test_grpc_request_np_uint64_type_2d_array_input_output_success():
    common_test_grpc_request_np_uint_type_2d_array_input_output_success(np.uint64)


@serving_test
def test_grpc_request_float_scalar_input_output_success():
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        instance["float_val"] = i * 2.2

    client = create_client("localhost:5500", base.servable_name, "float_plus_1")
    result = client.infer(instances)
    assert result[0]["value"] == 1
    assert result[1]["value"] == (2.2 + 1)
    assert result[2]["value"] == (4.4 + 1)


def common_test_grpc_request_np_float_type_scalar_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    y_data_list = []
    for i, instance in enumerate(instances):
        val = (i * 2.2) * (-1 if i % 2 == 0 else 1)  # 0, 2.2, -4.4
        val = np.array(val).astype(dtype)
        y_data_list.append((val + 1).tolist())
        instance["float_val"] = val

    client = create_client("localhost:5500", base.servable_name, "float_plus_1")
    result = client.infer(instances)
    assert result[0]["value"].dtype == dtype
    assert result[1]["value"].dtype == dtype
    assert result[2]["value"].dtype == dtype
    assert result[0]["value"].tolist() == y_data_list[0]
    assert result[1]["value"].tolist() == y_data_list[1]
    assert result[2]["value"].tolist() == y_data_list[2]


@serving_test
def test_grpc_request_np_float16_scalar_input_output_success():
    common_test_grpc_request_np_float_type_scalar_input_output_success(np.float16)


@serving_test
def test_grpc_request_np_float32_scalar_input_output_success():
    common_test_grpc_request_np_float_type_scalar_input_output_success(np.float32)


@serving_test
def test_grpc_request_np_float64_scalar_input_output_success():
    common_test_grpc_request_np_float_type_scalar_input_output_success(np.float64)


def common_test_grpc_request_np_float_type_1d_array_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    y_data_list = []
    for i, instance in enumerate(instances):
        val = (i * 2.2) * (-1 if i % 2 == 0 else 1)  # 0, 2.2, -4.4
        val = [val] * i
        val = np.array(val).astype(dtype)
        y_data_list.append((val + 1).tolist())
        instance["float_val"] = val

    client = create_client("localhost:5500", base.servable_name, "float_plus_1")
    result = client.infer(instances)
    assert result[0]["value"].dtype == dtype
    assert result[1]["value"].dtype == dtype
    assert result[2]["value"].dtype == dtype
    assert result[0]["value"].tolist() == y_data_list[0]
    assert result[1]["value"].tolist() == y_data_list[1]
    assert result[2]["value"].tolist() == y_data_list[2]


@serving_test
def test_grpc_request_np_float16_1d_array_input_output_success():
    common_test_grpc_request_np_float_type_1d_array_input_output_success(np.float16)


@serving_test
def test_grpc_request_np_float32_1d_array_input_output_success():
    common_test_grpc_request_np_float_type_1d_array_input_output_success(np.float32)


@serving_test
def test_grpc_request_np_float64_1d_array_input_output_success():
    common_test_grpc_request_np_float_type_1d_array_input_output_success(np.float64)


def common_test_grpc_request_np_float_type_2d_array_input_output_success(dtype):
    base = start_bool_int_float_grpc_server()
    # Client
    instances = [{}, {}, {}]
    y_data_list = []
    for i, instance in enumerate(instances):
        val = (i * 2.2) * (-1 if i % 2 == 0 else 1)  # 0, 2.2, -4.4
        val = [[val] * i] * i
        if i == 0:
            val = [[]]
        val = np.array(val).astype(dtype)
        y_data_list.append((val + 1).tolist())
        instance["float_val"] = val

    client = create_client("localhost:5500", base.servable_name, "float_plus_1")
    result = client.infer(instances)
    assert result[0]["value"].dtype == dtype
    assert result[1]["value"].dtype == dtype
    assert result[2]["value"].dtype == dtype
    assert result[0]["value"].tolist() == y_data_list[0]
    assert result[1]["value"].tolist() == y_data_list[1]
    assert result[2]["value"].tolist() == y_data_list[2]


@serving_test
def test_grpc_request_np_float16_2d_array_input_output_success():
    common_test_grpc_request_np_float_type_2d_array_input_output_success(np.float16)


@serving_test
def test_grpc_request_np_float32_2d_array_input_output_success():
    common_test_grpc_request_np_float_type_2d_array_input_output_success(np.float32)


@serving_test
def test_grpc_request_np_float64_2d_array_input_output_success():
    common_test_grpc_request_np_float_type_2d_array_input_output_success(np.float64)


@serving_test
def test_grpc_request_unix_domain_socket_success():
    base = init_str_servable()
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server_address = "unix:unix_socket_files/test_grpc_request_unix_domain_socket_success"
    server.start_grpc_server(server_address)
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = str_a[i]
        instance["text2"] = str_b[i]

    client = create_client(server_address, base.servable_name, "str_concat")
    result = client.infer(instances)
    print("result", result)
    assert result[0]["text"] == str_a[0] + str_b[0]
    assert result[1]["text"] == str_a[1] + str_b[1]
    assert result[2]["text"] == str_a[2] + str_b[2]
================================================
FILE: tests/ut/python/tests/test_model_call.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test Serving pipeline with client"""
import numpy as np

from common import start_serving_server
from common import serving_test, create_client


@serving_test
def test_call_model_two_input_one_output_normal_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2):
    y = model.call(x1, x2)
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    y = x1 + x2
    instances = [{"x1": x1, "x2": x2}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()


@serving_test
def test_call_model_two_input_one_output_multi_times_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4):
    y1 = model.call(x1, x2)
    y2 = model.call(x3, x4)
    return y1 + y2

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(call_model, x1, x2, x3, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[2.1, 3.2], [4.3, 5.4]], np.float32)
    x4 = np.array([[3.5, 4.6], [5.7, 6.8]], np.float32)
    y = x1 + x2 + x3 + x4
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()


@serving_test
def test_call_model_two_input_one_output_multi_times_2success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4):
    y1 = model.call(x1, x2)
    y2 = model.call(x3, x4)
    y = model.call(y1, y2)
    return y

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(call_model, x1, x2, x3, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[2.1, 3.2], [4.3, 5.4]], np.float32)
    x4 = np.array([[3.5, 4.6], [5.7, 6.8]], np.float32)
    y = x1 + x2 + x3 + x4
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()

@serving_test
def test_call_model_two_input_one_output_batch_call_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4):
    instances = []
    instances.append([x1, x2])
    instances.append((x3, x4))
    outputs = model.call(instances)  # return [[x1+x2], [x3+x4]]
    y1 = outputs[0][0]
    y2 = outputs[1][0]
    instances = []
    instances.append((y1, y2))
    outputs = model.call(instances)
    y = outputs[0][0]
    return y

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(call_model, x1, x2, x3, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[2.1, 3.2], [4.3, 5.4]], np.float32)
    x4 = np.array([[3.5, 4.6], [5.7, 6.8]], np.float32)
    y = x1 + x2 + x3 + x4
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()


@serving_test
def test_call_model_batch_call_one_input_one_output_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add_1_1.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4):
    instances = []
    instances.append([x1])
    instances.append([x2])
    instances.append([x3])
    outputs = model.call(instances)
    y1 = outputs[0][0]
    y2 = outputs[1][0]
    y3 = outputs[2][0]
    y4 = model.call(x4)
    return y1 + y2 + y3 + y4

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(call_model, x1, x2, x3, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add_1_1.mindir")
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[2.1, 3.2], [4.3, 5.4]], np.float32)
    x4 = np.array([[3.5, 4.6], [5.7, 6.8]], np.float32)
    y = x1 + x2 + x3 + x4
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()


@serving_test
def test_call_model_batch_call_one_input_two_output_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add_1_2.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4):
    _, y1 = model.call(x1)  # one instance
    _, y2 = model.call(x2)  # one instance
    _, y3 = model.call(x3)  # one instance
    _, y4 = model.call(x4)  # one instance
    return y1 + y2 + y3 + y4

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(call_model, x1, x2, x3, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add_1_2.mindir")
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[2.1, 3.2], [4.3, 5.4]], np.float32)
    x4 = np.array([[3.5, 4.6], [5.7, 6.8]], np.float32)
    y = x1 + x2 + x3 + x4 + 4
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()

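# Taken together, the tests above pin down the two calling conventions of
# Model.call used throughout this file (condensed sketch, names illustrative):
#
#     y = model.call(x1, x2)         # single instance: positional inputs,
#     out1, out2 = model.call(x1)    # outputs unpacked directly
#
#     outputs = model.call([[a1, a2], (b1, b2)])  # batch: list of instances
#     y_a = outputs[0][0]            # first instance, first output
#     y_b = outputs[1][0]            # second instance, first output
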
@serving_test
def test_call_model_batch_call_one_input_two_output_batch_call_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add_1_2.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4):
    instances = []
    instances.append([x1])  # one input
    outputs = model.call(instances)  # batch call, one instance
    _, y1 = outputs[0]
    instances = []
    instances.append([x2])  # one input
    outputs = model.call(instances)  # batch call, one instance
    _, y2 = outputs[0]
    instances = []
    instances.append([x3])  # one input
    instances.append([x4])
    outputs = model.call(instances)  # batch call, two instances
    _, y3 = outputs[0]
    _, y4 = outputs[1]
    return y1 + y2 + y3 + y4

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(call_model, x1, x2, x3, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add_1_2.mindir")
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[2.1, 3.2], [4.3, 5.4]], np.float32)
    x4 = np.array([[3.5, 4.6], [5.7, 6.8]], np.float32)
    y = x1 + x2 + x3 + x4 + 4
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()


@serving_test
def test_call_model_two_input_one_output_none_instances_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2):
    y = model.call()
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "Model(tensor_add.mindir).call() failed: no inputs provided" in result["error"]


@serving_test
def test_call_model_two_input_one_output_zero_instances_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2):
    y = model.call([])
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "Model(tensor_add.mindir).call() failed: Input instances count is 0" in result["error"]

@serving_test
def test_call_model_two_input_one_output_invalid_inputs_format_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2):
    y = model.call([x1, x2])  # expect to be model.call([[x1, x2]])
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "inputs format invalid" in result["error"]


@serving_test
def test_call_model_two_input_one_output_zero_inputs_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2):
    y = model.call([[]])
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 0 of instance 0 is not equal to the inputs count 2 of the model" in result["error"]


@serving_test
def test_call_model_two_input_one_output_data_size_error_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2):
    y = model.call(x1, x2)
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2, 3.3], [3.3, 4.4, 5.5]], np.float32)
    x2 = np.array([[5.5, 6.6, 7.7], [7.7, 8.8, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "Given model input 0 size 24 not match the size 16 defined in model" in result["error"]


@serving_test
def test_call_model_two_input_one_output_data_type_error_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2):
    y = model.call(x1, x2)
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.int32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.int32)
    instances = [{"x1": x1, "x2": x2}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "Given model input 0 data type kMSI_Int32 not match the data type kMSI_Float32 defined in model" in \
        result["error"]

@serving_test
def test_call_model_two_input_one_output_call_batch_data_size_error_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4):
    instances = []
    instances.append((x1, x2))
    instances.append((x3, x4))
    ys = model.call(instances)
    return ys[0][0] + ys[1][0]

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(call_model, x1, x2, x3, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2, 3.3], [3.3, 4.4, 5.5]], np.float32)
    x2 = np.array([[5.5, 6.6, 7.7], [7.7, 8.8, 8.8]], np.float32)
    x3 = np.array([[1.1, 2.2, 3.3], [3.3, 4.4, 5.5]], np.float32)
    x4 = np.array([[5.5, 6.6, 7.7], [7.7, 8.8, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "Given model input 0 size 24 not match the size 16 defined in model" in result["error"]


@serving_test
def test_call_model_two_input_one_output_call_batch_data_type_error_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4):
    instances = []
    instances.append((x1, x2))
    instances.append((x3, x4))
    ys = model.call(instances)
    return ys[0][0] + ys[1][0]

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(call_model, x1, x2, x3, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.int32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.int32)
    x3 = np.array([[1.1, 2.2], [3.3, 4.4]], np.int32)
    x4 = np.array([[5.5, 6.6], [7.7, 8.8]], np.int32)
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "Given model input 0 data type kMSI_Int32 not match the data type kMSI_Float32 defined in model" in \
        result["error"]


@serving_test
def test_call_model_two_input_one_output_more_inputs_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3):
    y = model.call(x1, x2, x3)
    return y

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(call_model, x1, x2, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 3 of instance 0 is not equal to the inputs count 2 of the model" in result["error"]

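# The failure tests rely on Serving's error-reporting convention: a
# server-side failure is surfaced through an "error" entry in the reply
# rather than a client-side exception. A defensive caller might look like
# this (hypothetical sketch):
#
#     result = client.infer(instances)
#     if isinstance(result, dict) and "error" in result:
#         print("inference failed:", result["error"])
#     else:
#         values = [item["y"] for item in result]  # per-instance outputs
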
@serving_test
def test_call_model_two_input_one_output_batch_call_more_inputs_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3):
    y = model.call([[x1, x2, x3]])
    return y[0][0]

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(call_model, x1, x2, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 3 of instance 0 is not equal to the inputs count 2 of the model" in result["error"]


@serving_test
def test_call_model_two_input_one_output_batch_call_more_inputs2_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3):
    y = model.call([[x1, x2], [x1, x2, x3]])
    return y[0][0]

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(call_model, x1, x2, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 3 of instance 1 is not equal to the inputs count 2 of the model" in result["error"]


@serving_test
def test_call_model_two_input_one_output_less_inputs_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3):
    y = model.call(x1)
    return y

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(call_model, x1, x2, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 1 of instance 0 is not equal to the inputs count 2 of the model" in result["error"]


@serving_test
def test_call_model_two_input_one_output_batch_call_less_inputs_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3):
    y = model.call([[x1]])
    return y[0][0]

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(call_model, x1, x2, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 1 of instance 0 is not equal to the inputs count 2 of the model" in result["error"]

@serving_test
def test_call_model_two_input_one_output_batch_call_less_inputs2_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3):
    y = model.call([[x1], [x1, x2]])
    return y[0][0]

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(call_model, x1, x2, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 1 of instance 0 is not equal to the inputs count 2 of the model" in result["error"]


@serving_test
def test_call_model_two_input_one_output_batch_call_less_inputs3_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3):
    y = model.call([[x1, x2], [x1]])
    return y[0][0]

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(call_model, x1, x2, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 1 of instance 1 is not equal to the inputs count 2 of the model" in result["error"]


@serving_test
def test_call_model_two_input_one_output_invalid_model_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
model_invalid = register.model.Model("tensor_add_test.mindir")

def call_model(x1, x2):
    y = model_invalid.call(x1, x2)
    return y[0][0]

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    instances = [{"x1": x1, "x2": x2}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "Model(tensor_add_test.mindir).call() failed: the model is not declared" in result["error"]

"predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_call_model_two_input_one_output_invalid_subgraph_failed(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def call_model(x1, x2): y = model.call(x1, x2, subgraph=1) return y[0][0] @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(call_model, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert "The model does not have subgraph of index 1, the subgraph count of the model is 1" in result["error"] @serving_test def test_call_model_two_input_one_output_two_subgraph_success(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file=["tensor_add.mindir", "tensor_sub.mindir"], model_format="MindIR", with_batch_dim=False) def call_model(x1, x2, x3): y = model.call(x1, x2, subgraph=0) # x1+x2 y = model.call(y, x3, subgraph=1) # y-x3 return y @register.register_method(output_names="y") def predict(x1, x2, x3): y = register.add_stage(call_model, x1, x2, x3, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"]) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) x3 = np.array([[7.5, 8.6], [9.7, 10.8]], np.float32) y = x1 + x2 - x3 instances = [{"x1": x1, "x2": x2, "x3": x3}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_call_model_diff_input_output_two_subgraph_success(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file=["tensor_add_2_3.mindir", "tensor_sub_3_2.mindir"], model_format="MindIR", with_batch_dim=False) def call_model(x1, x2, x3, x4, x5): y1, y2, y3 = model.call(x1, x2, subgraph=0) # tensor_add_2_3: 2 input, 3 output y4, y5 = model.call(x3, x4, x5, subgraph=1) # tensor_sub_3_2: 3 input, 2 output return y1+y4 @register.register_method(output_names="y") def predict(x1, x2, x3, x4, x5): y = register.add_stage(call_model, x1, x2, x3, x4, x5, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=["tensor_add_2_3.mindir", "tensor_sub_3_2.mindir"]) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) x3 = np.array([[7.5, 8.6], [9.7, 10.8]], np.float32) x4 = np.array([[8.5, 10.6], [6.7, 12.8]], np.float32) x5 = np.array([[9.5, 11.6], [8.7, 13.8]], np.float32) y = (x1 + x2) + (x3 - x4 - x5) instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_call_model_diff_input_output_two_subgraph2_success(): servable_content = r""" import numpy as np from mindspore_serving.server 
@serving_test
def test_call_model_diff_input_output_two_subgraph2_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file=["tensor_add_3_2.mindir", "tensor_sub_2_3.mindir"], model_format="MindIR",
                               with_batch_dim=False)

def call_model(x1, x2, x3, x4, x5):
    y1, y2 = model.call(x1, x2, x3, subgraph=0)  # tensor_add_3_2: 3 input, 2 output
    y3, y4, y5 = model.call(x4, x5, subgraph=1)  # tensor_sub_2_3: 2 input, 3 output
    return y1+y3

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(call_model, x1, x2, x3, x4, x5, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add_3_2.mindir", "tensor_sub_2_3.mindir"])
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[7.5, 8.6], [9.7, 10.8]], np.float32)
    x4 = np.array([[8.5, 10.6], [6.7, 12.8]], np.float32)
    x5 = np.array([[9.5, 11.6], [8.7, 13.8]], np.float32)
    y = (x1 + x2 + x3) + (x4 - x5)
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()


@serving_test
def test_call_model_diff_input_output_two_subgraph_inputs_count_not_match_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file=["tensor_add_3_2.mindir", "tensor_sub_2_3.mindir"], model_format="MindIR",
                               with_batch_dim=False)

def call_model(x1, x2, x3, x4, x5):
    y1, y2 = model.call(x1, x2, x3, subgraph=0)  # tensor_add_3_2: 3 input, 2 output
    y3, y4, y5 = model.call(x4, x5, x3, subgraph=1)  # tensor_sub_2_3: 2 input, 3 output
    return y1+y3

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(call_model, x1, x2, x3, x4, x5, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add_3_2.mindir", "tensor_sub_2_3.mindir"])
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[7.5, 8.6], [9.7, 10.8]], np.float32)
    x4 = np.array([[8.5, 10.6], [6.7, 12.8]], np.float32)
    x5 = np.array([[9.5, 11.6], [8.7, 13.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 3 of instance 0 is not equal to the inputs count 2 of the model" in result["error"]

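# The subgraph tests reduce to one pattern: a single declared model may bundle
# several MindIR graphs, and subgraph= selects which graph a call dispatches
# to (sketch using file names from the tests above):
#
#     model = register.declare_model(
#         model_file=["tensor_add.mindir", "tensor_sub.mindir"],
#         model_format="MindIR", with_batch_dim=False)
#
#     def combine(x1, x2, x3):
#         y = model.call(x1, x2, subgraph=0)    # graph 0: x1 + x2
#         return model.call(y, x3, subgraph=1)  # graph 1: y - x3
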
create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_call_model_diff_input_output_two_model_success(): servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add_2_3.mindir", model_format="MindIR", with_batch_dim=False) tensor_sub = register.declare_model(model_file="tensor_sub_3_2.mindir", model_format="MindIR", with_batch_dim=False) def call_model(x1, x2, x3, x4, x5): y1, y2, y3 = tensor_add.call(x1, x2) # tensor_add_2_3: 2 input, 3 output y4, y5 = tensor_sub.call(x3, x4, x5) # tensor_sub_3_2: 3 input, 2 output return y1+y4 @register.register_method(output_names="y") def predict(x1, x2, x3, x4, x5): y = register.add_stage(call_model, x1, x2, x3, x4, x5, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=["tensor_add_2_3.mindir", "tensor_sub_3_2.mindir"]) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) x3 = np.array([[7.5, 8.6], [9.7, 10.8]], np.float32) x4 = np.array([[8.5, 10.6], [6.7, 12.8]], np.float32) x5 = np.array([[9.5, 11.6], [8.7, 13.8]], np.float32) y = (x1 + x2) + (x3 - x4 - x5) instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_call_model_diff_input_output_two_model2_success(): servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add_3_2.mindir", model_format="MindIR", with_batch_dim=False) tensor_sub = register.declare_model(model_file="tensor_sub_2_3.mindir", model_format="MindIR", with_batch_dim=False) def call_model(x1, x2, x3, x4, x5): y1, y2 = tensor_add.call(x1, x2, x3) # tensor_add_3_2: 3 input, 2 output y3, y4, y5 = tensor_sub.call(x4, x5) # tensor_sub_2_3: 2 input, 3 output return y1+y3 @register.register_method(output_names="y") def predict(x1, x2, x3, x4, x5): y = register.add_stage(call_model, x1, x2, x3, x4, x5, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=["tensor_add_3_2.mindir", "tensor_sub_2_3.mindir"]) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) x3 = np.array([[7.5, 8.6], [9.7, 10.8]], np.float32) x4 = np.array([[8.5, 10.6], [6.7, 12.8]], np.float32) x5 = np.array([[9.5, 11.6], [8.7, 13.8]], np.float32) y = (x1 + x2 + x3) + (x4 - x5) instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_call_model_diff_input_output_two_model_inputs_count_not_match_failed(): servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add_3_2.mindir", model_format="MindIR", with_batch_dim=False) tensor_sub = register.declare_model(model_file="tensor_sub_2_3.mindir", model_format="MindIR", with_batch_dim=False) def call_model(x1, x2, x3, x4, x5): y1, y2 = tensor_add.call(x1, x2) # tensor_add_3_2: 3 input, 2 output y3, y4, y5 = tensor_sub.call(x4, x5) # tensor_sub_2_3: 2 input, 3 output return 
@serving_test
def test_call_model_diff_input_output_two_model_inputs_count_not_match_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add_3_2.mindir", model_format="MindIR", with_batch_dim=False)
tensor_sub = register.declare_model(model_file="tensor_sub_2_3.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2, x3, x4, x5):
    y1, y2 = tensor_add.call(x1, x2)  # tensor_add_3_2: 3 input, 2 output
    y3, y4, y5 = tensor_sub.call(x4, x5)  # tensor_sub_2_3: 2 input, 3 output
    return y1+y3

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(call_model, x1, x2, x3, x4, x5, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add_3_2.mindir", "tensor_sub_2_3.mindir"])
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    x3 = np.array([[7.5, 8.6], [9.7, 10.8]], np.float32)
    x4 = np.array([[8.5, 10.6], [6.7, 12.8]], np.float32)
    x5 = np.array([[9.5, 11.6], [8.7, 13.8]], np.float32)
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert "The inputs count 2 of instance 0 is not equal to the inputs count 3 of the model" in result["error"]


@serving_test
def test_call_model_diff_input_output_two_model_with_batch_dim_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add_3_2.mindir", model_format="MindIR", with_batch_dim=True)
tensor_sub = register.declare_model(model_file="tensor_sub_2_3.mindir", model_format="MindIR", with_batch_dim=True)

def call_model(x1, x2, x3, x4, x5):
    y1, y2 = tensor_add.call(x1, x2, x3)  # tensor_add_3_2: 3 input, 2 output
    y3, y4, y5 = tensor_sub.call(x4, x5)  # tensor_sub_2_3: 2 input, 3 output
    return y1+y3

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(call_model, x1, x2, x3, x4, x5, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add_3_2.mindir", "tensor_sub_2_3.mindir"])
    # Client
    x1 = np.array([[3.3, 4.4]], np.float32)
    x2 = np.array([[7.7, 8.8]], np.float32)
    x3 = np.array([[9.7, 10.8]], np.float32)
    x4 = np.array([[6.7, 12.8]], np.float32)
    x5 = np.array([[8.7, 13.8]], np.float32)
    y = (x1 + x2 + x3) + (x4 - x5)
    instances = [{"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5}]

    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()


================================================
FILE: tests/ut/python/tests/test_model_context.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test Model DeviceInfo"""
import os
import numpy as np

from common import serving_test, start_serving_server, create_client
from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo
from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions

""" try: context = Context(thread_affinity_core_list=1) assert False except RuntimeError as e: assert "Parameter 'thread_affinity_core_list' should be tuple/list of int, but actually " in str(e) context = Context(thread_num=3, thread_affinity_core_list=[1, 2, 3], enable_parallel=True) model_context = context.model_context assert model_context.thread_num == 3 assert set(model_context.thread_affinity_core_list) == {1, 2, 3} assert model_context.enable_parallel == 1 # declare model and start_servable and load model and build model gpu_device_info = GPUDeviceInfo(precision_mode="fp16") gpu_map = gpu_device_info.context_map assert gpu_map["precision_mode"] == "fp16" assert gpu_map["device_type"] == "gpu" cpu_device_info = CPUDeviceInfo(precision_mode="fp16") cpu_map = cpu_device_info.context_map assert cpu_map["precision_mode"] == "fp16" assert cpu_map["device_type"] == "cpu" ascend_device_info = AscendDeviceInfo(insert_op_cfg_path="some path of insert_op_cfg_path", input_format="NHWC1C0", input_shape="input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1", output_type="FP16", precision_mode="allow_mix_precision", op_select_impl_mode="high_precision", fusion_switch_config_path="some path of fusion_switch_config_path", buffer_optimize_mode="l1_and_l2_optimize") ascend310_map = ascend_device_info.context_map assert ascend310_map["insert_op_cfg_path"] == "some path of insert_op_cfg_path" assert ascend310_map["input_format"] == "NHWC1C0" assert ascend310_map["input_shape"] == "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1" assert ascend310_map["output_type"] == "FP16" assert ascend310_map["precision_mode"] == "allow_mix_precision" assert ascend310_map["op_select_impl_mode"] == "high_precision" assert ascend310_map["fusion_switch_config_path"] == "some path of fusion_switch_config_path" assert ascend310_map["buffer_optimize_mode"] == "l1_and_l2_optimize" assert ascend310_map["device_type"] == "ascend" context.append_device_info(gpu_device_info) context.append_device_info(cpu_device_info) context.append_device_info(ascend_device_info) assert len(model_context.device_list) == 3 assert model_context.device_list[0]["device_type"] == "gpu" assert model_context.device_list[1]["precision_mode"] == "fp16" assert model_context.device_list[2]["precision_mode"] == "allow_mix_precision" @serving_test def test_model_context_device_info_repeat_append_ascend_failed(): """ Feature: Model Device info Description: Repeat append AscendDeviceInfo Expectation: raise RuntimeError """ context = Context() context.append_device_info(AscendDeviceInfo()) try: context.append_device_info(AscendDeviceInfo()) assert False except RuntimeError as e: assert "Device info of type ascend has already been appended" in str(e) @serving_test def test_model_context_options_set_get_success(): """ Feature: Model options Description: Test set and get options Expectation: the values gotten are equal to the values set. 
""" gpu_options = GpuOptions(precision_mode="fp16") gpu_device_list = gpu_options.context.model_context.device_list assert gpu_device_list[0]["device_type"] == "gpu" assert gpu_device_list[0]["precision_mode"] == "fp16" acl_options = AclOptions(insert_op_cfg_path="some path of insert_op_cfg_path", input_format="NHWC1C0", input_shape="input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1", output_type="FP16", precision_mode="allow_mix_precision", op_select_impl_mode="high_precision", fusion_switch_config_path="some path of fusion_switch_config_path", buffer_optimize_mode="l1_and_l2_optimize") acl_device_list = acl_options.context.model_context.device_list assert acl_device_list[0]["insert_op_cfg_path"] == "some path of insert_op_cfg_path" assert acl_device_list[0]["input_format"] == "NHWC1C0" assert acl_device_list[0]["input_shape"] == "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1" assert acl_device_list[0]["output_type"] == "FP16" assert acl_device_list[0]["precision_mode"] == "allow_mix_precision" assert acl_device_list[0]["op_select_impl_mode"] == "high_precision" assert acl_device_list[0]["fusion_switch_config_path"] == "some path of fusion_switch_config_path" assert acl_device_list[0]["buffer_optimize_mode"] == "l1_and_l2_optimize" assert acl_device_list[0]["device_type"] == "ascend" @serving_test def test_model_context_gpu_device_info_serving_server_success(): """ Feature: Model Device info Description: Test set gpu device info Expectation: Serving server work well. """ servable_content = r""" import numpy as np from mindspore_serving.server import register from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions context = Context() context.append_device_info(GPUDeviceInfo(precision_mode="fp16")) model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False, context = context) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ os.environ["SERVING_ENABLE_GPU_DEVICE"] = "1" base = start_serving_server(servable_content, device_type="GPU") # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_model_context_cpu_device_info_serving_server_success(): """ Feature: Model Device info Description: Test set cpu device info Expectation: Serving server work well. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions context = Context() context.append_device_info(CPUDeviceInfo(precision_mode="fp16")) model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False, context = context) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ os.environ["SERVING_ENABLE_CPU_DEVICE"] = "1" base = start_serving_server(servable_content, device_type="CPU") # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_model_context_ascend_device_info_serving_server_success(): """ Feature: Model Device info Description: Test set ascend device info Expectation: Serving server work well. """ servable_content = r""" import numpy as np from mindspore_serving.server import register from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions context = Context() context.append_device_info(AscendDeviceInfo(input_format="NHWC1C0")) model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False, context = context) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content, device_type="Ascend") # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_model_context_all_device_info_serving_server_success(): """ Feature: Model Device info Description: Test set cpu, gpu, ascend device info, and serving select one device info based on inference so Expectation: Serving server work well. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions context = Context() context.append_device_info(AscendDeviceInfo(input_format="NHWC1C0")) context.append_device_info(GPUDeviceInfo(precision_mode="fp16")) context.append_device_info(CPUDeviceInfo(precision_mode="fp16")) model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False, context = context) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_model_context_acl_options_serving_server_success(): """ Feature: Model Device info Description: Test set ascend options Expectation: Serving server work well. """ servable_content = r""" import numpy as np from mindspore_serving.server import register from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions options = AclOptions(input_format="NHWC1C0") model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False, options = options) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_model_context_gpu_options_serving_server_success(): """ Feature: Model Device info Description: Test set gpu options Expectation: Serving server work well. """ servable_content = r""" import numpy as np from mindspore_serving.server import register from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions options = GpuOptions(precision_mode="fp16") model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False, options = options) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_model_context_gpu_options_invalid_parameter_failed(): """ Feature: Model Device info Description: Test set gpu options Expectation: Serving server start failed. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions options = GpuOptions(precision_mode="origi") model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False, options = options) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "Gpu device info 'precision_mode' can only be 'origin', 'fp16'" in str(e) @serving_test def test_model_context_gpu_options_invalid_parameter2_failed(): """ Feature: Model Device info Description: Test set gpu options Expectation: Serving server start failed. """ servable_content = r""" import numpy as np from mindspore_serving.server import register from mindspore_serving.server.register import Context, GPUDeviceInfo, CPUDeviceInfo from mindspore_serving.server.register import AscendDeviceInfo, GpuOptions, AclOptions options = GpuOptions(precision_xxx_mode="origin") model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False, options = options) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "Set gpu device info failed, unsupported option precision_xxx_mode" in str(e) @serving_test def test_model_context_gpu_cpu_device_device_ids_none_serving_server_success(): """ Feature: Model Device info Description: device_ids=None, and support GPU, CPU, running on CPU Expectation: Serving server work well. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ os.environ["SERVING_ENABLE_GPU_DEVICE"] = "1" os.environ["SERVING_ENABLE_CPU_DEVICE"] = "1" base = start_serving_server(servable_content, device_type=None, device_ids=None) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_model_context_only_support_gpu_device_device_ids_none_serving_server_failed(): """ Feature: Model Device info Description: device_ids=None, and only support GPU, running on CPU failed Expectation: Serving server startup failed. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) return y """ os.environ["SERVING_ENABLE_GPU_DEVICE"] = "1" try: start_serving_server(servable_content, device_type=None, device_ids=None) except RuntimeError as e: assert "has models declared by declare_model, but parameter 'device_ids' of ServableStartConfig is not set in" \ " Serving startup script when the MindSpore or Lite inference package not support CPU" in str(e) ================================================ FILE: tests/ut/python/tests/test_multi_model.py ================================================ # Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ import numpy as np from common import serving_test, create_client from common import start_serving_server def is_float_equal(left, right): return (np.abs(left - right) < 0.00001).all() @serving_test def test_multi_model_success(): servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2, x3): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) y = register.add_stage(tensor_sub, y, x3, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"]) # Client instances = [] ys = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1) y = x1 + x2 - x3 instances.append({"x1": x1, "x2": x2, "x3": x3}) ys.append(y) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y"], ys[0]) assert is_float_equal(result[1]["y"], ys[1]) assert is_float_equal(result[2]["y"], ys[2]) @serving_test def test_multi_model_2_success(): servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2, x3, x4, x5): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) y = register.add_stage(tensor_sub, y, x3, outputs_count=1) y = register.add_stage(tensor_add, y, x4, 
import numpy as np

from common import serving_test, create_client
from common import start_serving_server


def is_float_equal(left, right):
    return (np.abs(left - right) < 0.00001).all()


@serving_test
def test_multi_model_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_sub, y, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 - x3
        instances.append({"x1": x1, "x2": x2, "x3": x3})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_multi_model_2_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_sub, y, x3, outputs_count=1)
    y = register.add_stage(tensor_add, y, x4, outputs_count=1)
    y = register.add_stage(tensor_sub, y, x5, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(10):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[2.5, 3.3], [4.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[2.7, 3.8], [4.9, 5.0]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 - x3 + x4 - x5
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    for i in range(10):
        assert is_float_equal(result[i]["y"], ys[i])


@serving_test
def test_multi_model_with_batch_dim_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True)
tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=True)

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_sub, y, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 - x3
        instances.append({"x1": x1, "x2": x2, "x3": x3})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_multi_model_with_function_front_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)

def test(x1, x2):
    return x1+x2+1

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(test, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_sub, y, x3, outputs_count=1)
    y = register.add_stage(tensor_add, y, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[2.5, 3.3], [4.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[2.7, 3.8], [4.9, 5.0]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 + 1 - x3 + x4
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_multi_model_with_function_tail_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)

def test(x1, x2):
    return x1+x2+1

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(tensor_sub, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_add, y, x3, outputs_count=1)
    y = register.add_stage(test, y, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[2.5, 3.3], [4.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[2.7, 3.8], [4.9, 5.0]], np.float32) * 1.1 * (i + 1)
        y = x1 - x2 + x3 + x4 + 1
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_multi_model_with_function_mid_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)

def test(x1, x2):
    return x1+x2+1

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(tensor_sub, x1, x2, outputs_count=1)
    y = register.add_stage(test, y, x3, outputs_count=1)
    y = register.add_stage(tensor_add, y, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[2.5, 3.3], [4.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[2.7, 3.8], [4.9, 5.0]], np.float32) * 1.1 * (i + 1)
        y = x1 - x2 + x3 + 1 + x4
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_multi_model_with_function_interlace_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)

def test(x1, x2):
    return x1+x2+1

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5, x6):
    y = register.add_stage(test, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_sub, y, x3, outputs_count=1)
    y = register.add_stage(test, y, x4, outputs_count=1)
    y = register.add_stage(tensor_add, y, x5, outputs_count=1)
    y = register.add_stage(test, y, x6, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[2.5, 3.3], [4.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[2.7, 3.8], [4.9, 5.0]], np.float32) * 1.1 * (i + 1)
        x6 = np.array([[3.7, 4.8], [5.9, 6.0]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 - x3 + x4 + x5 + x6 + 3
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5, "x6": x6})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_multi_model_with_function_call_model_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
tensor_sub = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)

def add_test(x1, x2):
    return tensor_add.call(x1, x2)

def sub_test(x1, x2):
    return tensor_sub.call(x1, x2)

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    y = register.add_stage(add_test, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_sub, y, x3, outputs_count=1)
    y = register.add_stage(tensor_add, y, x4, outputs_count=1)
    y = register.add_stage(sub_test, y, x5, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[2.5, 3.3], [4.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[2.7, 3.8], [4.9, 5.0]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 - x3 + x4 - x5
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_multi_model_diff_input_output_count_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add_3_2.mindir", model_format="MindIR", with_batch_dim=True)
tensor_sub = register.declare_model(model_file="tensor_sub_2_3.mindir", model_format="MindIR", with_batch_dim=True)

@register.register_method(output_names=["y1", "y2", "y3"])
def predict(x1, x2, x3):
    y1, y2 = register.add_stage(tensor_add, x1, x2, x3, outputs_count=2)
    y1, y2, y3 = register.add_stage(tensor_sub, y1, y2, outputs_count=3)
    return y1, y2, y3
"""
    base = start_serving_server(servable_content, model_file=["tensor_add_3_2.mindir", "tensor_sub_2_3.mindir"])
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        # for tensor_add_3_2
        y1 = x1 + x2 + x3
        y2 = y1 + 1
        # for tensor_sub_2_3
        y1 = y1 - y2
        y2 = y1 + 1
        y3 = y1 + 2
        instances.append({"x1": x1, "x2": x2, "x3": x3})
        ys.append([y1, y2, y3])
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y1"], ys[0][0])
    assert is_float_equal(result[0]["y2"], ys[0][1])
    assert is_float_equal(result[0]["y3"], ys[0][2])
    assert is_float_equal(result[1]["y1"], ys[1][0])
    assert is_float_equal(result[1]["y2"], ys[1][1])
    assert is_float_equal(result[1]["y3"], ys[1][2])
    assert is_float_equal(result[2]["y1"], ys[2][0])
    assert is_float_equal(result[2]["y2"], ys[2][1])
    assert is_float_equal(result[2]["y3"], ys[2][2])


================================================
FILE: tests/ut/python/tests/test_python_parallel.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
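"""test Serving with num_parallel_workers, with master, worker and client"""
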
import os

import numpy as np

from common import ServingTestBase
from common import serving_test, create_client
from mindspore_serving import server


def start_serving_server(servable_content, model_file="tensor_add.mindir", parallel_number=0, device_ids=0):
    base = ServingTestBase()
    base.init_servable_with_servable_config(1, servable_content, model_file=model_file)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=device_ids,
                                                      num_parallel_workers=parallel_number, version_number=1))
    server.start_grpc_server("0.0.0.0:5500")
    return base


def is_float_equal(left, right):
    return (np.abs(left - right) < 0.00001).all()
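

# Each worker process writes a log file under serving_logs/: device workers as
# log_<servable>_device<device_id>_version<version>.log and extra python-stage workers as
# log_<servable>_extra<extra_id>_version<version>.log. A handled request leaves a
# 'WorkerRequestHandle Time Cost' entry there, which check_infer_log looks for.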
def check_infer_log(servable_name, version, device_id, extra_id):
    if device_id is not None:
        log_file = f"serving_logs/log_{servable_name}_device{device_id}_version{version}.log"
    else:
        log_file = f"serving_logs/log_{servable_name}_extra{extra_id}_version{version}.log"
    if not os.path.isfile(log_file):
        print(f"Not found log file {log_file}", flush=True)
        return False
    with open(log_file) as fp:
        text = fp.read()
    if "WorkerRequestHandle Time Cost" not in text:
        print(f"Not found log 'WorkerRequestHandle Time Cost' in log file {log_file}", flush=True)
        return False
    print(f"Found log 'WorkerRequestHandle Time Cost' in log file {log_file}", flush=True)
    return True


@serving_test
def test_python_parallel_without_model_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def function_test(x1, x2):
    y = x1+x2
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(function_test, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, parallel_number=2, device_ids=0)
    # Client
    ys = []
    instances = []
    for i in range(20):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * (i + 1)
        instances.append({"x1": x1, "x2": x2})
        ys.append(x1 + x2)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    for i in range(len(instances)):
        assert is_float_equal(result[i]["y"], ys[i])
    assert check_infer_log(base.servable_name, base.version_number, device_id=0, extra_id=None)
    assert check_infer_log(base.servable_name, base.version_number, device_id=None, extra_id=0)


@serving_test
def test_python_parallel_with_model_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def function_test(x1, x2):
    return x1+1, x2+1

def function_test2(y):
    return y + 1

@register.register_method(output_names="y")
def predict(x1, x2):
    x1, x2 = register.add_stage(function_test, x1, x2, outputs_count=2)
    y = register.add_stage(model, x1, x2, outputs_count=1)
    y = register.add_stage(function_test2, y, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, parallel_number=2, device_ids=0)
    # Client
    ys = []
    instances = []
    for i in range(20):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * (i + 1)
        instances.append({"x1": x1, "x2": x2})
        ys.append(x1 + x2 + 3)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    for i in range(len(instances)):
        assert is_float_equal(result[i]["y"], ys[i])
    assert check_infer_log(base.servable_name, base.version_number, device_id=0, extra_id=None)
    assert check_infer_log(base.servable_name, base.version_number, device_id=None, extra_id=0)


@serving_test
def test_python_parallel_with_call_model_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def function_call_model(x1, x2):
    return model.call(x1, x2)

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y1 = register.add_stage(function_call_model, x1, x2, outputs_count=1)
    y2 = register.add_stage(model, x3, x4, outputs_count=1)
    y = register.add_stage(function_call_model, y1, y2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, parallel_number=2, device_ids=0)
    # Client
    ys = []
    instances = []
    for i in range(20):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * (i + 1)
        x3 = np.array([[3.1, 4.2], [5.3, 6.4]], np.float32) * (i + 1)
        x4 = np.array([[0.5, 9.6], [8.7, 7.8]], np.float32) * (i + 1)
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4})
        y = (x1 + x2) + (x3 + x4)
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    for i in range(len(instances)):
        assert is_float_equal(result[i]["y"], ys[i])
    assert check_infer_log(base.servable_name, base.version_number, device_id=0, extra_id=None)
    assert check_infer_log(base.servable_name, base.version_number, device_id=None, extra_id=0)


@serving_test
def test_python_parallel_with_call_model_multi_process_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def function_call_model(x1, x2):
    return model.call(x1, x2)

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y1 = register.add_stage(function_call_model, x1, x2, outputs_count=1)
    y2 = register.add_stage(model, x3, x4, outputs_count=1)
    y = register.add_stage(function_call_model, y1, y2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, parallel_number=4, device_ids=(0, 1))
    # Client
    ys = []
    instances = []
    for i in range(20):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * (i + 1)
        x3 = np.array([[3.1, 4.2], [5.3, 6.4]], np.float32) * (i + 1)
        x4 = np.array([[0.5, 9.6], [8.7, 7.8]], np.float32) * (i + 1)
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4})
        y = (x1 + x2) + (x3 + x4)
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    for i in range(len(instances)):
        assert is_float_equal(result[i]["y"], ys[i])
    assert check_infer_log(base.servable_name, base.version_number, device_id=0, extra_id=None)
    assert check_infer_log(base.servable_name, base.version_number, device_id=1, extra_id=None)
    assert check_infer_log(base.servable_name, base.version_number, device_id=None, extra_id=0)
    assert check_infer_log(base.servable_name, base.version_number, device_id=None, extra_id=1)


@serving_test
def test_python_parallel_with_call_model_with_batch_size_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True)

def function_call_model(x1, x2):
    return model.call(x1, x2)

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y1 = register.add_stage(function_call_model, x1, x2, outputs_count=1)
    y2 = register.add_stage(model, x3, x4, outputs_count=1)
    y = register.add_stage(function_call_model, y1, y2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, parallel_number=2, device_ids=0)
    # Client
    ys = []
    instances = []
    for i in range(20):
        x1 = np.array([[3.3, 4.4]], np.float32) * (i + 1)
        x2 = np.array([[7.7, 8.8]], np.float32) * (i + 1)
        x3 = np.array([[5.3, 6.4]], np.float32) * (i + 1)
        x4 = np.array([[8.7, 7.8]], np.float32) * (i + 1)
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4})
        y = (x1 + x2) + (x3 + x4)
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    for i in range(len(instances)):
        assert is_float_equal(result[i]["y"], ys[i])
    assert check_infer_log(base.servable_name, base.version_number, device_id=0, extra_id=None)
    assert check_infer_log(base.servable_name, base.version_number, device_id=None, extra_id=0)


@serving_test
def test_python_parallel_multi_models_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

add_model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)
sub_model = register.declare_model(model_file="tensor_sub.mindir", model_format="MindIR", with_batch_dim=False)

def function_call_model(x1, x2):
    return add_model.call(x1, x2)

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y1 = register.add_stage(add_model, x1, x2, outputs_count=1)
    y2 = register.add_stage(sub_model, x3, x4, outputs_count=1)
    y = register.add_stage(function_call_model, y1, y2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, parallel_number=2, device_ids=0,
                                model_file=["tensor_add.mindir", "tensor_sub.mindir"])
    # Client
    ys = []
    instances = []
    for i in range(20):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * (i + 1)
        x3 = np.array([[3.1, 4.2], [5.3, 6.4]], np.float32) * (i + 1)
        x4 = np.array([[0.5, 9.6], [8.7, 7.8]], np.float32) * (i + 1)
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4})
        y = (x1 + x2) + (x3 - x4)
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    for i in range(len(instances)):
        assert is_float_equal(result[i]["y"], ys[i])
    assert check_infer_log(base.servable_name, base.version_number, device_id=0, extra_id=None)
    assert check_infer_log(base.servable_name, base.version_number, device_id=None, extra_id=0)


@serving_test
def test_python_parallel_multi_models_diff_input_output_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

add_model = register.declare_model(model_file="tensor_add_2_3.mindir", model_format="MindIR", with_batch_dim=False)
sub_model = register.declare_model(model_file="tensor_sub_3_2.mindir", model_format="MindIR", with_batch_dim=False)

def function_call_model(x1, x2):
    return x1 + x2

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5):
    _, y1, _ = register.add_stage(add_model, x1, x2, outputs_count=3)  # 2 input, 3 output
    _, y2 = register.add_stage(sub_model, x3, x4, x5, outputs_count=2)  # 3 input, 2 output
    y = register.add_stage(function_call_model, y1, y2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, parallel_number=2, device_ids=0,
                                model_file=["tensor_add_2_3.mindir", "tensor_sub_3_2.mindir"])
    # Client
    ys = []
    instances = []
    for i in range(20):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * (i + 1)
        x3 = np.array([[3.1, 4.2], [5.3, 6.4]], np.float32) * (i + 1)
        x4 = np.array([[0.5, 9.6], [8.7, 7.8]], np.float32) * (i + 1)
        x5 = np.array([[0.2, 9.5], [8.2, 7.1]], np.float32) * (i + 1)
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5})
        y = (x1 + x2 + 1) + (x3 - x4 - x5 + 1)
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    for i in range(len(instances)):
        assert is_float_equal(result[i]["y"], ys[i])
    assert check_infer_log(base.servable_name, base.version_number, device_id=0, extra_id=None)
    assert check_infer_log(base.servable_name, base.version_number, device_id=None, extra_id=0)


================================================
FILE: tests/ut/python/tests/test_register_method.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test Serving pipeline with client"""

import numpy as np

from common import ServingTestBase
from common import serving_test, create_client
from common import start_serving_server
from mindspore_serving import server


@serving_test
def test_register_method_with_model_success():
    """
    Feature: test register method
    Description: method with only python function stage, python function has model.call
    Expectation: success to start serving server.
    """
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def call_model(x1, x2):
    y = model.call(x1, x2)
    return y

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(call_model, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, version_number=1, start_version_number=1)
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    y = x1 + x2
    instances = [{"x1": x1, "x2": x2}]
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names=["x1", "x2"]) def predict(x1, x2): return x1, x2 """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) instances = [{"x1": x1, "x2": x2}] * 3 client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert len(result) == 3 assert (result[0]["x1"] == x1).all() assert (result[0]["x2"] == x2).all() @serving_test def test_register_method_without_register_method_failed(): """ Feature: test register method Description: without any methods Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "There is no method registered for servable" in str(e) @serving_test def test_register_method_two_input_one_output_one_model_stage_input_more_failed(): """ Feature: test register method Description: model input count not equal to model stage input count Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2, x3): y = register.add_stage(tensor_add, x1, x2, x3, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir") assert False except RuntimeError as e: assert "The inputs count 3 in register_method not equal to the count 2 defined in model" in str(e) @serving_test def test_register_method_two_input_one_output_one_model_stage_input_less_failed(): """ Feature: test register method Description: model input count not equal to model stage input count Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2, x3): y = register.add_stage(tensor_add, x1, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir") assert False except RuntimeError as e: assert "The inputs count 1 in register_method not equal to the count 2 defined in model" in str(e) @serving_test def test_register_method_two_input_one_output_one_model_stage_input_less2_failed(): """ Feature: test register method Description: model input count not equal to some model stage input count Expectation: failed to start serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2, x3): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) y = register.add_stage(tensor_add, y, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir") assert False except RuntimeError as e: assert "RegisterInputOutputInfo failed, inputs count 1 not match old count 2" in str(e) @serving_test def test_register_method_two_input_one_output_one_model_stage_input_less3_failed(): """ Feature: test register method Description: model input count not equal to model stage input count Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2, x3): y = register.add_stage(tensor_add, x1, outputs_count=1) return y @register.register_method(output_names="y") def predict2(x1, x2, x3): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir") assert False except RuntimeError as e: assert "RegisterInputOutputInfo failed, inputs count 2 not match old count 1" in str(e) @serving_test def test_register_method_two_input_one_output_one_model_stage_with_batch_dim_input_more_failed(): """ Feature: test register method Description: model input count not equal to model stage input count, with_batch_dim is True Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) @register.register_method(output_names="y") def predict(x1, x2, x3): y = register.add_stage(tensor_add, x1, x2, x3, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir") assert False except RuntimeError as e: assert "The inputs count 3 in register_method not equal to the count 2 defined in model" in str(e) @serving_test def test_register_method_two_input_one_output_one_model_stage_with_batch_dim_input_less_failed(): """ Feature: test register method Description: model input count not equal to model stage input count, with_batch_dim is True Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) @register.register_method(output_names="y") def predict(x1, x2, x3): y = register.add_stage(tensor_add, x1, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir") assert False except RuntimeError as e: assert "The inputs count 1 in register_method not equal to the count 2 defined in model" in str(e) @serving_test def test_register_method_two_input_two_output_one_model_stage_output_more_failed(): """ Feature: test register method Description: model output count not equal to model stage output count, with_batch_dim is True Expectation: failed to start serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add_2_2.mindir", model_format="MindIR", with_batch_dim=True) @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2, y3 = register.add_stage(tensor_add, x1, x2, outputs_count=3) return y1, y2 """ try: start_serving_server(servable_content, model_file="tensor_add_2_2.mindir") assert False except RuntimeError as e: assert "The outputs count 3 in register_method not equal to the count 2 defined in model" in str(e) @serving_test def test_register_method_three_input_two_output_one_model_stage_output_less_failed(): """ Feature: test register method Description: model output count not equal to model stage output count, with_batch_dim is True Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add_2_3.mindir", model_format="MindIR", with_batch_dim=True) @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(tensor_add, x1, x2, outputs_count=2) return y1, y2 """ try: start_serving_server(servable_content, model_file="tensor_add_2_3.mindir") assert False except RuntimeError as e: assert "The outputs count 2 in register_method not equal to the count 3 defined in model" in str(e) @serving_test def test_register_method_three_input_two_output_one_model_stage_output_less2_failed(): """ Feature: test register method Description: model output count not equal to some model stage output count, with_batch_dim is True Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add_2_3.mindir", model_format="MindIR", with_batch_dim=True) @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2, y3 = register.add_stage(tensor_add, x1, x2, outputs_count=3) y1, y2 = register.add_stage(tensor_add, y1, y2, outputs_count=2) return y1, y2 """ try: start_serving_server(servable_content, model_file="tensor_add_2_3.mindir") assert False except RuntimeError as e: assert "RegisterInputOutputInfo failed, outputs count 2 not match old count 3" in str(e) @serving_test def test_register_method_three_input_two_output_one_model_stage_output_less3_failed(): """ Feature: test register method Description: model output count not equal to some model stage output count, with_batch_dim is True Expectation: failed to start serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add_2_3.mindir", model_format="MindIR", with_batch_dim=True) @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2, y3 = register.add_stage(tensor_add, x1, x2, outputs_count=3) return y1, y2 @register.register_method(output_names=["y1", "y2"]) def predict2(x1, x2): y1, y2 = register.add_stage(tensor_add, x1, x2, outputs_count=2) return y1, y2 """ try: start_serving_server(servable_content, model_file="tensor_add_2_3.mindir") assert False except RuntimeError as e: assert "RegisterInputOutputInfo failed, outputs count 2 not match old count 3" in str(e) @serving_test def test_register_method_model_file_repeat_failed(): """ Feature: test register method Description: same model file repeatedly used in diff declare_model Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) tensor_add2 = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) @register.register_method(output_names=["y"]) def predict(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir") assert False except RuntimeError as e: assert "model file 'tensor_add.mindir' has already been used" in str(e) @serving_test def test_register_method_model_file_repeat2_failed(): """ Feature: test register method Description: same model file repeatedly used in diff declare_model Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file=["tensor_add.mindir", "tensor_sub.mindir"], model_format="MindIR") tensor_add2 = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") @register.register_method(output_names=["y"]) def predict(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"]) assert False except RuntimeError as e: assert "model file 'tensor_add.mindir' has already been used" in str(e) @serving_test def test_register_method_model_file_repeat3_failed(): """ Feature: test register method Description: same model file repeatedly used in diff declare_model Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add2 = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") tensor_add = register.declare_model(model_file=["tensor_add.mindir", "tensor_sub.mindir"], model_format="MindIR") @register.register_method(output_names=["y"]) def predict(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file=["tensor_add.mindir", "tensor_sub.mindir"]) assert False except RuntimeError as e: assert "model file 'tensor_add.mindir' has already been used" in str(e) @serving_test def test_register_method_method_registered_repeat_failed(): """ Feature: test register method Description: methods with same name Expectation: failed to start serving server. 
""" servable_content = r""" from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "Method add_cast has been registered more than once." in str(e) @serving_test def test_register_method_input_arg_invalid_failed(): """ Feature: test register method Description: method input args invalid Expectation: failed to start serving server. """ servable_content = r""" from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") @register.register_method(output_names=["y"]) def add_cast(x1, **x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "'add_cast' input x2 cannot be VAR_KEYWORD !" in str(e) @serving_test def test_register_method_input_arg_invalid2_failed(): """ Feature: test register method Description: method input args invalid Expectation: failed to start serving server. """ servable_content = r""" from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") @register.register_method(output_names=["y"]) def add_cast(x1, *x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "'add_cast' input x2 cannot be VAR_POSITIONAL !" in str(e) @serving_test def test_register_method_function_stage_invalid_input_failed(): """ Feature: test register method Description: stage input args invalid Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def func_test(x1, x2): return x1+1, x2+1 @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.add_stage(func_test, x1, np.ones([2,2]), outputs_count=2) y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "Each value of parameter *args is a placeholder for data and" in str(e) @serving_test def test_register_method_function_stage_invalid_input2_failed(): """ Feature: test register method Description: stage input args invalid Expectation: failed to start serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def postprocess(y, data): return y @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) y = register.add_stage(postprocess, y, np.ones([2,2]), outputs_count=1) return y """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "Each value of parameter *args is a placeholder for data and" in str(e) @serving_test def test_register_method_model_stage_invalid_input_failed(): """ Feature: test register method Description: stage input args invalid Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, np.ones([2,2]), outputs_count=1) return y """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "Each value of parameter *args is a placeholder for data and" in str(e) @serving_test def test_register_method_invalid_return_failed(): """ Feature: test register method Description: method return invalid Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") @register.register_method(output_names=["y", "data"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y, np.ones([2,2]) """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "Each value returned is a placeholder for data and must come from the method" in str(e) @serving_test def test_register_method_function_stage_batch_input_count_not_same_failed(): """ Feature: test register method Description: function stage input count diff in diff method Expectation: failed to start serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) x1, x2 = register.add_stage(func_test_batch, x1, x2, outputs_count=2, batch_size=4) return y @register.register_method(output_names=["y"]) def add_cast2(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) x1, x2 = register.add_stage(func_test_batch, x1, x2, y, outputs_count=2, batch_size=4) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert f"'{base.servable_name}.func_test_batch' inputs count 3 " \ f"not match last registered count 2" in str(e) @serving_test def test_register_method_function_stage_batch_input_count_not_same2_failed(): """ Feature: test register method Description: function stage input count diff in diff method Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.add_stage(func_test_batch, x1, x2, outputs_count=2, batch_size=4) y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y @register.register_method(output_names=["y"]) def add_cast2(x1, x2, x3): x1, x2 = register.add_stage(func_test_batch, x1, x2, x3, outputs_count=2, batch_size=4) y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert f"'{base.servable_name}.func_test_batch' inputs count 3 " \ f"not match last registered count 2" in str(e) @serving_test def test_register_method_function_stage_batch_output_count_not_same_failed(): """ Feature: test register method Description: function stage output count diff in diff method Expectation: failed to start serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) x1, x2 = register.add_stage(func_test_batch, x1, x2, outputs_count=2, batch_size=4) return y @register.register_method(output_names=["y"]) def add_cast2(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) x1, x2, x3 = register.add_stage(func_test_batch, x1, x2, outputs_count=3, batch_size=4) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert f"'{base.servable_name}.func_test_batch' outputs count 3 " \ f"not match last registered count 2" in str(e) @serving_test def test_register_method_function_stage_batch_output_count_not_same2_failed(): """ Feature: test register method Description: function stage output count diff in diff method Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.add_stage(func_test_batch, x1, x2, outputs_count=2, batch_size=4) y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y @register.register_method(output_names=["y"]) def add_cast2(x1, x2): x1, x2, x3 = register.add_stage(func_test_batch, x1, x2, outputs_count=3, batch_size=4) y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert f"'{base.servable_name}.func_test_batch' outputs count 3 " \ f"not match last registered count 2" in str(e) @serving_test def test_register_method_method_output_count_not_match_output_names_failed(): """ Feature: test register method Description: outputs count registered not equal to the count return in function Expectation: failed to start serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y, x2 """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "Method return output size 2 not match registered 1" in str(e) @serving_test def test_register_method_method_python_function_batch_size_exist_inconsistently_failed(): """ Feature: test register method Description: python function used in multi add_stage, one with batch_size, other without batch_size Expectation: failed to start serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def stage_test_fun(x1, x2): return x1+x2 @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(stage_test_fun, x1, x2, outputs_count=1) return y @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(stage_test_fun, x1, x2, outputs_count=1, batch_size=4) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "parameter 'batch_size' in multiple 'add_stage' should be enabled or disabled consistently" in str(e) ================================================ FILE: tests/ut/python/tests/test_restful_base64_data.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================
"""test Serving RESTful, with master, worker and client"""

import base64

import numpy as np

from common import ServingTestBase, serving_test
from common import servable_config_import, servable_config_declare_servable
from common_restful import compare_float_value, check_number_result, post_restful
from common_restful import start_str_restful_server, start_bytes_restful_server, start_bool_int_float_restful_server
from mindspore_serving import server


def b64_decode_to_str(a):
    return bytes.decode(base64.b64decode(a["b64"]))


def common_test_restful_base64_str_scalar_input_output_success(shape):
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        if shape is None:
            instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode(), "type": "str"}
            instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode(), "type": "str"}
        else:
            instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode(), "type": "str",
                                 'shape': shape}
            instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode(), "type": "str",
                                 'shape': shape}

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    result = result["instances"]
    assert result[0]["text"] == str_a[0] + str_b[0]
    assert result[1]["text"] == str_a[1] + str_b[1]
    assert result[2]["text"] == str_a[2] + str_b[2]


@serving_test
def test_restful_base64_str_scalar_input_output_success():
    common_test_restful_base64_str_scalar_input_output_success(shape=None)


@serving_test
def test_restful_base64_str_scalar_shape1_input_output_success():
    common_test_restful_base64_str_scalar_input_output_success(shape=[1])


@serving_test
def test_restful_base64_str_scalar_shape_empty_input_output_success():
    common_test_restful_base64_str_scalar_input_output_success(shape=[])


@serving_test
def test_restful_base64_empty_str_input_output_success():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode(), "type": "str"}
        instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode(), "type": "str"}

    result = post_restful("localhost:5500", base.servable_name, "str_empty", instances)
    result = result["instances"]
    assert result[0]["text"] == ""
    assert result[1]["text"] == "456"
    assert result[2]["text"] == ""


@serving_test
def test_restful_base64_str_scalar_invalid_shape0_input_failed():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode(), "type": "str", "shape": [0]}
        instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode(), "type": "str", "shape": [0]}

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    assert "only support scalar when data type is string or bytes, please check 'type' or 'shape'" \
           in str(result["error_msg"])


@serving_test
def test_restful_base64_str_scalar_invalid_shape_input_failed():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode(), "type": "str", 'shape': [2]}
        instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode(), "type": "str", 'shape': [2]}

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    assert "json object, only support scalar when data type is string or bytes, please check 'type' or 'shape'" \
           in str(result["error_msg"])


@serving_test
def test_restful_base64_str_1d_array_failed():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = [{"b64": base64.b64encode(str.encode(str_a[i])).decode(), "type": "str"},
                             {"b64": base64.b64encode(str.encode(str_a[i])).decode(), "type": "str"}]
        instance["text2"] = [{"b64": base64.b64encode(str.encode(str_b[i])).decode(), "type": "str"},
                             {"b64": base64.b64encode(str.encode(str_b[i])).decode(), "type": "str"}]

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    assert "json array, string or bytes type only support one item" in str(result["error_msg"])


def common_test_restful_bytes_input_output_success(shape):
    base = start_bytes_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        if shape is not None:
            instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode(), "shape": shape}
            instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode(), "shape": shape}
        else:
            instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode()}
            instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode()}

    result = post_restful("localhost:5500", base.servable_name, "bytes_concat", instances)
    result = result["instances"]
    assert b64_decode_to_str(result[0]["text"]) == str_a[0] + str_b[0]
    assert b64_decode_to_str(result[1]["text"]) == str_a[1] + str_b[1]
    assert b64_decode_to_str(result[2]["text"]) == str_a[2] + str_b[2]


@serving_test
def test_restful_bytes_input_output_success():
    common_test_restful_bytes_input_output_success(None)


@serving_test
def test_restful_bytes_empty_shape_success():
    common_test_restful_bytes_input_output_success([])


@serving_test
def test_restful_bytes_shape1_success():
    common_test_restful_bytes_input_output_success([1])


@serving_test
def test_restful_empty_bytes_input_output_success():
    base = start_bytes_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode()}
        instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode()}

    result = post_restful("localhost:5500", base.servable_name, "bytes_empty", instances)
    result = result["instances"]
    assert b64_decode_to_str(result[0]["text"]) == ""
    assert b64_decode_to_str(result[1]["text"]) == "456"
    assert b64_decode_to_str(result[2]["text"]) == ""


@serving_test
def test_restful_bytes_1d_array_failed():
    base = start_bytes_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = [{"b64": base64.b64encode(str.encode(str_a[i])).decode()},
                             {"b64": base64.b64encode(str.encode(str_a[i])).decode()}]
        instance["text2"] = [{"b64": base64.b64encode(str.encode(str_b[i])).decode()},
                             {"b64": base64.b64encode(str.encode(str_b[i])).decode()}]
result = post_restful("localhost:5500", base.servable_name, "bytes_concat", instances) assert "json array, string or bytes type only support one item" in str(result["error_msg"]) @serving_test def test_restful_bytes_invalid_shape_input_failed(): base = start_bytes_restful_server() # Client instances = [{}, {}, {}] str_a = ["ABC", "DEF", "HIJ"] str_b = ["123", "456", "789"] for i, instance in enumerate(instances): instance["text1"] = {"b64": base64.b64encode(str.encode(str_a[i])).decode(), 'shape': [0]} instance["text2"] = {"b64": base64.b64encode(str.encode(str_b[i])).decode(), 'shape': [0]} result = post_restful("localhost:5500", base.servable_name, "bytes_concat", instances) assert "only support scalar when data type is string or bytes, please check 'type' or 'shape'" \ in result["error_msg"] @serving_test def test_restful_base64_bool_scalar_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for i, instance in enumerate(instances): val = np.int8(i % 2 == 0) instance["bool_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "bool"} result = post_restful("localhost:5500", base.servable_name, "bool_not", instances) result = result["instances"] assert not result[0]["value"] assert result[1]["value"] assert not result[2]["value"] @serving_test def test_restful_base64_bool_1d_array_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for i, instance in enumerate(instances): val = [(i % 2 == 0)] * (i + 1) val = np.array(val) instance["bool_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "bool", "shape": [i + 1]} result = post_restful("localhost:5500", base.servable_name, "bool_not", instances) result = result["instances"] assert result[0]["value"] == [False] assert result[1]["value"] == [True, True] assert result[2]["value"] == [False, False, False] @serving_test def test_restful_base64_bool_2d_array_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for i, instance in enumerate(instances): val = (i % 2 == 0) val = [[val] * (i + 1)] * (i + 1) val = np.array(val) instance["bool_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "bool", "shape": [i + 1, i + 1]} result = post_restful("localhost:5500", base.servable_name, "bool_not", instances) result = result["instances"] assert result[0]["value"] == [[False]] assert result[1]["value"] == [[True, True], [True, True]] assert result[2]["value"] == [[False, False, False], [False, False, False], [False, False, False]] @serving_test def test_restful_base64_int_scalar_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for i, instance in enumerate(instances): val = np.int32(i * 2) instance["int_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "int32"} result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances) result = result["instances"] assert result[0]["value"] == 1 assert result[1]["value"] == 3 assert result[2]["value"] == 5 @serving_test def test_restful_base64_int_1d_empty_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for i, instance in enumerate(instances): if i % 2 == 0: val = [] else: val = [i * 2] * (i + 1) val = np.array(val).astype(np.int32) instance["int_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "int32", "shape": val.shape} result = post_restful("localhost:5500", 
base.servable_name, "int_plus_1", instances) result = result["instances"] assert result[0]["value"] == [] assert result[1]["value"] == [3, 3] assert result[2]["value"] == [] @serving_test def test_restful_base64_int_2d_empty_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for i, instance in enumerate(instances): if i % 2 == 0: val = [[]] else: val = [i * 2] * (i + 1) val = np.array(val).astype(np.int32) instance["int_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "int32", "shape": val.shape} result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances) result = result["instances"] assert result[0]["value"] == [[]] assert result[1]["value"] == [3, 3] assert result[2]["value"] == [[]] @serving_test def test_restful_base64_int_2d_empty_invalid_shape_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for _, instance in enumerate(instances): val = [[]] val = np.array(val).astype(np.int32) instance["int_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "int32", "shape": [1, 2, 0, 1]} result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances) assert "json object, key is 'shape', invalid shape value" in result["error_msg"] @serving_test def test_restful_base64_int_1d_array_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for i, instance in enumerate(instances): val = i * 2 val = [val] * (i + 1) val = np.array(val).astype(np.int32) instance["int_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "int32", "shape": val.shape} result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances) result = result["instances"] assert result[0]["value"] == [1] assert result[1]["value"] == [3, 3] assert result[2]["value"] == [5, 5, 5] def common_test_restful_base64_int_type_2d_array_input_output_success(dtype): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype_str_map = {np.int8: "int8", np.int16: "int16", np.int32: "int32", np.int64: "int64"} assert dtype in dtype_str_map for i, instance in enumerate(instances): val = (i + 1) * 2 * (-1 if i % 2 == 0 else 1) # -2, 4, -6 val = [[val] * (i + 1)] * (i + 1) val = np.array(val).astype(dtype) instance["int_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': dtype_str_map[dtype], "shape": val.shape} result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances) result = result["instances"] assert result[0]["value"] == [[-1]] assert result[1]["value"] == [[5, 5], [5, 5]] assert result[2]["value"] == [[-5, -5, -5], [-5, -5, -5], [-5, -5, -5]] @serving_test def test_restful_base64_int8_2d_array_input_output_success(): common_test_restful_base64_int_type_2d_array_input_output_success(np.int8) @serving_test def test_restful_base64_int16_2d_array_input_output_success(): common_test_restful_base64_int_type_2d_array_input_output_success(np.int16) @serving_test def test_restful_base64_int32_2d_array_input_output_success(): common_test_restful_base64_int_type_2d_array_input_output_success(np.int32) @serving_test def test_restful_base64_int64_2d_array_input_output_success(): common_test_restful_base64_int_type_2d_array_input_output_success(np.int64) def common_test_restful_base64_uint_type_2d_array_input_output_success(dtype): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype_str_map = {np.uint8: "uint8", 
np.uint16: "uint16", np.uint32: "uint32", np.uint64: "uint64"} assert dtype in dtype_str_map for i, instance in enumerate(instances): val = i * 2 val = [[val] * (i + 1)] * (i + 1) val = np.array(val).astype(dtype) instance["int_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': dtype_str_map[dtype], "shape": val.shape} result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances) result = result["instances"] assert result[0]["value"] == [[1]] assert result[1]["value"] == [[3, 3], [3, 3]] assert result[2]["value"] == [[5, 5, 5], [5, 5, 5], [5, 5, 5]] @serving_test def test_restful_base64_uint8_2d_array_input_output_success(): common_test_restful_base64_uint_type_2d_array_input_output_success(np.uint8) @serving_test def test_restful_base64_uint16_2d_array_input_output_success(): common_test_restful_base64_uint_type_2d_array_input_output_success(np.uint16) @serving_test def test_restful_base64_uint32_2d_array_input_output_success(): common_test_restful_base64_uint_type_2d_array_input_output_success(np.uint32) @serving_test def test_restful_base64_uint64_2d_array_input_output_success(): common_test_restful_base64_uint_type_2d_array_input_output_success(np.uint64) @serving_test def test_restful_base64_float_scalar_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for i, instance in enumerate(instances): val = np.float32(i * 2.2) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp32"} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) result = result["instances"] assert result[0]["value"] == 1.0 assert abs(result[1]["value"] - (2.2 + 1)) < 0.001 assert abs(result[2]["value"] - (4.4 + 1)) < 0.001 @serving_test def test_restful_base64_float_1d_array_input_output_success(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] y_data_list = [] for i, instance in enumerate(instances): val = [i * 2.2 * (-1 if i % 2 == 0 else 1)] * (i + 1) # [0], [2.2, 2.2], [-4.4, -4.4, -4.4] val = np.array(val).astype(np.float32) y_data_list.append(val + 1) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp32", 'shape': [i + 1]} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) check_number_result(result, y_data_list, "value") def common_test_restful_base64_float_type_2d_array_input_output_success(dtype, dtype_str=None): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"} assert dtype in dtype_str_map if dtype_str is None: dtype_str = dtype_str_map[dtype] y_data_list = [] for i, instance in enumerate(instances): val = i * 2.2 * (-1 if i % 2 == 0 else 1) # 0, 2.2 ,-4.4 val = [[val] * (i + 1)] * (i + 1) val = np.array(val).astype(dtype) y_data_list.append(val + 1) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': dtype_str, 'shape': [i + 1, i + 1]} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) check_number_result(result, y_data_list, "value") @serving_test def test_restful_base64_float16_2d_array_input_output_success(): common_test_restful_base64_float_type_2d_array_input_output_success(np.float16) @serving_test def test_restful_base64_float32_2d_array_input_output_success(): common_test_restful_base64_float_type_2d_array_input_output_success(np.float32) @serving_test def 

@serving_test
def test_restful_base64_float64_2d_array_input_output_success():
    common_test_restful_base64_float_type_2d_array_input_output_success(np.float64)


@serving_test
def test_restful_base64_float16_2_2d_array_input_output_success():
    common_test_restful_base64_float_type_2d_array_input_output_success(np.float16, "float16")


@serving_test
def test_restful_base64_float32_2_2d_array_input_output_success():
    common_test_restful_base64_float_type_2d_array_input_output_success(np.float32, "float32")


@serving_test
def test_restful_base64_float64_2_2d_array_input_output_success():
    common_test_restful_base64_float_type_2d_array_input_output_success(np.float64, "float64")


@serving_test
def test_restful_base64_mix_all_type_success():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += r"""
def func_test(bool_val, int_val, float_val, str_val, bytes_val):
    return ~bool_val, int_val+1, float_val+1, str_val+"123", str.encode(bytes.decode(bytes_val.tobytes()) + "456")


@register.register_method(output_names=['bool_val', 'int_val', 'float_val', 'str_val', 'bytes_val'])
def mix_all_type(bool_val, int_val, float_val, str_val, bytes_val):
    bool_val, int_val, float_val, str_val, bytes_val = \
        register.add_stage(func_test, bool_val, int_val, float_val, str_val, bytes_val, outputs_count=5)
    return bool_val, int_val, float_val, str_val, bytes_val
"""
    base.init_servable_with_servable_config(1, servable_content)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_restful_server("0.0.0.0:5500")
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        float_val = np.array([2.2, 3.3]).astype(np.float32)
        instance["float_val"] = {"b64": base64.b64encode(float_val.tobytes()).decode(), 'type': "fp32", 'shape': [2]}
        int_val = np.array([2, 3]).astype(np.int32)
        instance["int_val"] = {"b64": base64.b64encode(int_val.tobytes()).decode(), 'type': "int32", 'shape': [2]}
        bool_val = np.array([True, False])
        instance["bool_val"] = {"b64": base64.b64encode(bool_val.tobytes()).decode(), 'type': "bool", 'shape': [2]}
        str_val = "ABC"
        instance["str_val"] = {"b64": base64.b64encode(str.encode(str_val)).decode(), 'type': "str", 'shape': []}
        bytes_val = "DEF"
        instance["bytes_val"] = {"b64": base64.b64encode(str.encode(bytes_val)).decode(), 'type': "bytes",
                                 'shape': []}

    result = post_restful("localhost:5500", base.servable_name, "mix_all_type", instances)
    result = result["instances"]
    for i in range(3):
        compare_float_value(result[i]["float_val"], [3.2, 4.3])
        assert result[i]["int_val"] == [3, 4]
        assert result[i]["bool_val"] == [False, True]
        assert result[i]["str_val"] == "ABC123"
        assert b64_decode_to_str(result[i]["bytes_val"]) == "DEF456"


@serving_test
def test_restful_base64_without_b64_key_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"}
    assert dtype in dtype_str_map
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 1)] * (i + 1)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        instance["float_val"] = {'type': dtype_str_map[dtype], 'shape': [i + 1, i + 1]}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "'b64' should be specified only one time" in result["error_msg"]

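
# Together with the missing-'b64' case above, the tests below probe the RESTful parser's
# error paths for the b64 object form one field at a time: a non-string 'b64' value,
# corrupt base64 text, an empty buffer, an unknown or non-string 'type', and shapes
# whose element count does not match the decoded data.
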

@serving_test
def test_restful_base64_b64_invalid_type_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"}
    assert dtype in dtype_str_map
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 1)] * (i + 1)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        instance["float_val"] = {'b64': 123, 'type': dtype_str_map[dtype], 'shape': [i + 1, i + 1]}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "get scalar data failed, type is string, but json is not string type" in result["error_msg"]


@serving_test
def test_restful_base64_b64_invalid_value_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"}
    assert dtype in dtype_str_map
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 1)] * (i + 1)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        b64_val = base64.b64encode(val.tobytes()).decode()
        b64_val = '+==+==' + b64_val[:len('+==+==')]
        instance["float_val"] = {'b64': b64_val, 'type': dtype_str_map[dtype], 'shape': [i + 1, i + 1]}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "is illegal b64 encode string" in result["error_msg"]


@serving_test
def test_restful_base64_b64_value_empty_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"}
    assert dtype in dtype_str_map
    for i, instance in enumerate(instances):
        instance["float_val"] = {'b64': "", 'type': dtype_str_map[dtype], 'shape': [i + 1, i + 1]}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "decode base64 size:0; Given info: type:float16; type size:2; element nums:1" in result["error_msg"]


@serving_test
def test_restful_base64_dtype_unknow_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"}
    assert dtype in dtype_str_map
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 1)] * (i + 1)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "dtype_unknow",
                                 'shape': [i + 1, i + 1]}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "Parser request failed, json object, specified type:'dtype_unknow' is illegal" in result["error_msg"]

post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) assert "Parser request failed, json object, specified type:'' is illegal" in result["error_msg"] @serving_test def test_restful_base64_dtype_invalid_type_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype = np.float16 dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"} assert dtype in dtype_str_map y_data_list = [] for i, instance in enumerate(instances): val = i * 2.2 * (-1 if i % 2 == 0 else 1) # 0, 2.2 ,-4.4 val = [[val] * (i + 1)] * (i + 1) val = np.array(val).astype(dtype) y_data_list.append(val + 1) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': 1, 'shape': [i + 1, i + 1]} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) assert "json object, key is 'type', value should be string type" in result["error_msg"] @serving_test def test_restful_base64_float16_2d_array_dtype_not_match_empty_data_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype = np.float16 y_data_list = [] for i, instance in enumerate(instances): val = [[]] val = np.array(val).astype(dtype) y_data_list.append(val + 1) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp16", 'shape': [i + 1, i + 1]} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) assert "Parser request failed, size is not matched" in result["error_msg"] @serving_test def test_restful_base64_float16_2d_array_dtype_not_match_size_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype = np.float16 y_data_list = [] for i, instance in enumerate(instances): val = i * 2.2 * (-1 if i % 2 == 0 else 1) # 0, 2.2 ,-4.4 val = [[val] * (i + 2)] * (i + 2) val = np.array(val).astype(dtype) y_data_list.append(val + 1) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp32", 'shape': [i + 2, i + 2]} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) assert "Parser request failed, size is not matched" in result["error_msg"] @serving_test def test_restful_base64_float16_2d_array_shape_large_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype = np.float16 dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"} assert dtype in dtype_str_map y_data_list = [] for i, instance in enumerate(instances): val = i * 2.2 * (-1 if i % 2 == 0 else 1) # 0, 2.2 ,-4.4 val = [[val] * (i + 1)] * (i + 1) val = np.array(val).astype(dtype) y_data_list.append(val + 1) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': dtype_str_map[dtype], 'shape': [i + 2, i + 2]} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) assert "Parser request failed, size is not matched" in result["error_msg"] @serving_test def test_restful_base64_float16_2d_array_shape_small_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype = np.float16 dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"} assert dtype in dtype_str_map y_data_list = [] for i, instance in enumerate(instances): val = i * 2.2 * (-1 if i % 2 == 0 else 1) # 0, 2.2 ,-4.4 val = [[val] * (i + 2)] * (i + 2) val = np.array(val).astype(dtype) y_data_list.append(val + 1) instance["float_val"] = {"b64": 

@serving_test
def test_restful_base64_float16_2d_array_shape_small_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"}
    assert dtype in dtype_str_map
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 2)] * (i + 2)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': dtype_str_map[dtype],
                                 'shape': [i + 1, i + 1]}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "Parser request failed, size is not matched" in result["error_msg"]


@serving_test
def test_restful_base64_float16_2d_array_shape_small2_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    dtype_str_map = {np.float16: "fp16", np.float32: "fp32", np.float64: "fp64"}
    assert dtype in dtype_str_map
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 2)] * (i + 2)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': dtype_str_map[dtype],
                                 'shape': [i + 2, i]}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "Parser request failed, size is not matched" in result["error_msg"]


@serving_test
def test_restful_base64_float16_2d_array_empty_shape_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 2)] * (i + 2)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp16", 'shape': []}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "Parser request failed, size is not matched" in result["error_msg"]


@serving_test
def test_restful_base64_float16_2d_array_none_shape_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 2)] * (i + 2)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp16"}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "Parser request failed, size is not matched" in result["error_msg"]


@serving_test
def test_restful_base64_float16_2d_array_invalid_2d_shape_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    dtype = np.float16
    y_data_list = []
    for i, instance in enumerate(instances):
        val = i * 2.2 * (-1 if i % 2 == 0 else 1)  # 0, 2.2 ,-4.4
        val = [[val] * (i + 2)] * (i + 2)
        val = np.array(val).astype(dtype)
        y_data_list.append(val + 1)
        instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp16", "shape": [[]]}

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    assert "json object, key is 'shape', array value should be unsigned integer" in result["error_msg"]

post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) assert "json object, key is 'shape', array value should be unsigned integer" in result["error_msg"] @serving_test def test_restful_base64_float16_2d_array_float_shape_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype = np.float16 y_data_list = [] for i, instance in enumerate(instances): val = i * 2.2 * (-1 if i % 2 == 0 else 1) # 0, 2.2 ,-4.4 val = [[val] * (i + 2)] * (i + 2) val = np.array(val).astype(dtype) y_data_list.append(val + 1) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp16", "shape": [1.1]} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) assert "json object, key is 'shape', array value should be unsigned integer" in result["error_msg"] @serving_test def test_restful_base64_float16_2d_array_negative_shape_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] dtype = np.float16 y_data_list = [] for i, instance in enumerate(instances): val = i * 2.2 * (-1 if i % 2 == 0 else 1) # 0, 2.2 ,-4.4 val = [[val] * (i + 2)] * (i + 2) val = np.array(val).astype(dtype) y_data_list.append(val + 1) instance["float_val"] = {"b64": base64.b64encode(val.tobytes()).decode(), 'type': "fp16", "shape": [-1]} result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances) assert "json object, key is 'shape', array value should be unsigned integer" in result["error_msg"] ================================================ FILE: tests/ut/python/tests/test_restful_json_data.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================
"""test Serving RESTful, with master, worker and client"""
import numpy as np

from common import serving_test
from common_restful import compare_float_value, post_restful
from common_restful import start_str_restful_server, start_bool_int_float_restful_server


@serving_test
def test_restful_str_scalar_input_output_success():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = str_a[i]
        instance["text2"] = str_b[i]

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    result = result["instances"]
    assert result[0]["text"] == str_a[0] + str_b[0]
    assert result[1]["text"] == str_a[1] + str_b[1]
    assert result[2]["text"] == str_a[2] + str_b[2]


@serving_test
def test_restful_str_scalar_shape1_input_output_success():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = [str_a[i]]
        instance["text2"] = [str_b[i]]

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    result = result["instances"]
    assert result[0]["text"] == str_a[0] + str_b[0]
    assert result[1]["text"] == str_a[1] + str_b[1]
    assert result[2]["text"] == str_a[2] + str_b[2]


@serving_test
def test_restful_empty_str_input_output_success():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = str_a[i]
        instance["text2"] = str_b[i]

    result = post_restful("localhost:5500", base.servable_name, "str_empty", instances)
    result = result["instances"]
    assert result[0]["text"] == ""
    assert result[1]["text"] == "456"
    assert result[2]["text"] == ""


@serving_test
def test_restful_str_2d_array_one_item_input_output_failed():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = [[str_a[i]]]
        instance["text2"] = [[str_b[i]]]

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    assert "bytes or string type input shape can only be (1,) or empty, but given shape is [1, 1]" \
           in result["error_msg"]


@serving_test
def test_restful_str_1d_array_input_failed():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = [str_a[i], str_a[i]]
        instance["text2"] = [str_b[i], str_b[i]]

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    assert "json array, string or bytes type only support one item" in str(result["error_msg"])


@serving_test
def test_restful_str_invalid_array_input_failed():
    base = start_str_restful_server()
    # Client
    instances = [{}, {}, {}]
    str_a = ["ABC", "DEF", "HIJ"]
    str_b = ["123", "456", "789"]
    for i, instance in enumerate(instances):
        instance["text1"] = [str_a[i], [str_a[i]]]
        instance["text2"] = [str_b[i], [str_b[i]]]

    result = post_restful("localhost:5500", base.servable_name, "str_concat", instances)
    assert "json array, string or bytes type only support one item" in str(result["error_msg"])

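
# For reference, a plain-JSON request body for "str_concat" looks like the following
# (illustrative values, not taken from the helpers above):
#   {"instances": [{"text1": "ABC", "text2": "123"}]}
# i.e. scalars and nested lists are written directly in JSON, with the b64 object form
# reserved for binary data.
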

@serving_test
def test_restful_str_invalid_str_message_failed():
    base = start_str_restful_server()
    # Client
    post_payload = np.array([1.1, 2.2], np.float32).tobytes()
    result = post_restful("localhost:5500", base.servable_name, "str_concat", None, post_payload=post_payload)
    assert "Illegal JSON format" in str(result["error_msg"])


@serving_test
def test_restful_bool_scalar_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        instance["bool_val"] = (i % 2 == 0)

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    result = result["instances"]
    assert not result[0]["value"]
    assert result[1]["value"]
    assert not result[2]["value"]


@serving_test
def test_restful_bool_1d_array_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        instance["bool_val"] = [(i % 2 == 0)] * (i + 1)

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    result = result["instances"]
    assert result[0]["value"] == [False]
    assert result[1]["value"] == [True, True]
    assert result[2]["value"] == [False, False, False]


@serving_test
def test_restful_bool_2d_array_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = (i % 2 == 0)
        val = [[val] * (i + 1)] * (i + 1)
        instance["bool_val"] = val

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    result = result["instances"]
    assert result[0]["value"] == [[False]]
    assert result[1]["value"] == [[True, True], [True, True]]
    assert result[2]["value"] == [[False, False, False], [False, False, False], [False, False, False]]


@serving_test
def test_restful_bool_invalid_array_array_scalar_mix_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[False], True]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "invalid json array: json type is not array" in result['error_msg']


@serving_test
def test_restful_bool_invalid_array2_scalar_array_mix_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [False, [True]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, data should be number, bool, string or bytes" in result['error_msg']


@serving_test
def test_restful_bool_invalid_array3_array_dim_not_match_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[False, True], [True]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "invalid json array: json size is 1, the dim 1 expected to be 2" in result['error_msg']


@serving_test
def test_restful_bool_invalid_array4_array_dim_not_match_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[[False, True]], [[True]]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "invalid json array: json size is 1, the dim 2 expected to be 2" in result['error_msg']


@serving_test
def test_restful_int_scalar_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = i * 2
        instance["int_val"] = val

    result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances)
    result = result["instances"]
    assert result[0]["value"] == 1
    assert result[1]["value"] == 3
    assert result[2]["value"] == 5


@serving_test
def test_restful_int_empty_input_output_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        if i % 2 == 0:
            val = []
        else:
            val = [i * 2] * (i + 1)
        instance["int_val"] = val

    result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances)
    assert "json array, shape is empty" in result["error_msg"]


@serving_test
def test_restful_int_1d_array_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = i * 2
        val = [val] * (i + 1)
        instance["int_val"] = val

    result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances)
    result = result["instances"]
    assert result[0]["value"] == [1]
    assert result[1]["value"] == [3, 3]
    assert result[2]["value"] == [5, 5, 5]


@serving_test
def test_restful_int_2d_array_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = i * 2
        val = [[val] * (i + 1)] * (i + 1)
        instance["int_val"] = val

    result = post_restful("localhost:5500", base.servable_name, "int_plus_1", instances)
    result = result["instances"]
    assert result[0]["value"] == [[1]]
    assert result[1]["value"] == [[3, 3], [3, 3]]
    assert result[2]["value"] == [[5, 5, 5], [5, 5, 5], [5, 5, 5]]


@serving_test
def test_restful_float_scalar_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = i * 2.2
        instance["float_val"] = val

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    result = result["instances"]
    compare_float_value(result[0]["value"], 1.0)
    compare_float_value(result[1]["value"], 2.2 + 1)
    compare_float_value(result[2]["value"], 4.4 + 1)


@serving_test
def test_restful_float_1d_array_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = [i * 2.2] * (i + 1)
        instance["float_val"] = val

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    result = result["instances"]
    compare_float_value(result[0]["value"], [1.0])
    compare_float_value(result[1]["value"], [3.2, 3.2])
    compare_float_value(result[2]["value"], [5.4, 5.4, 5.4])


@serving_test
def test_restful_float_2d_array_input_output_success():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for i, instance in enumerate(instances):
        val = i * 2.2
        val = [[val] * (i + 1)] * (i + 1)
        instance["float_val"] = val

    result = post_restful("localhost:5500", base.servable_name, "float_plus_1", instances)
    result = result["instances"]
    compare_float_value(result[0]["value"], [[1.0]])
    compare_float_value(result[1]["value"], [[3.2, 3.2], [3.2, 3.2]])
    compare_float_value(result[2]["value"], [[5.4, 5.4, 5.4], [5.4, 5.4, 5.4], [5.4, 5.4, 5.4]])

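
# Each instance in the cases below mixes element types inside one JSON array; the server
# requires all elements of an array input to share a single type, so every request must
# be rejected with a type-mismatch (or string/bytes arity) error.
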

@serving_test
def test_restful_mix_bool_int_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[False, True], [1, 1]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, elements type is not equal" in result['error_msg']


@serving_test
def test_restful_mix_bool_int2_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[False, 1]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, elements type is not equal" in result['error_msg']


@serving_test
def test_restful_mix_float_int_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[1.1, 1.2], [1, 1]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, elements type is not equal" in result['error_msg']


@serving_test
def test_restful_mix_float_int2_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[1.1, 1]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, elements type is not equal" in result['error_msg']


@serving_test
def test_restful_mix_int_float_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[1, 1], [1.1, 1.2]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, elements type is not equal" in result['error_msg']


@serving_test
def test_restful_mix_int_float2_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[1, 1.2]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, elements type is not equal" in result['error_msg']


@serving_test
def test_restful_mix_str_float_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [["a", "b"], [1.1, 1.2]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "string or bytes type only support one item" in result['error_msg']


@serving_test
def test_restful_mix_str_float2_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [["a", 1.2]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "string or bytes type only support one item" in result['error_msg']


@serving_test
def test_restful_mix_float_str_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[1.1, 1.2], ["a", "b"]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, elements type is not equal" in result['error_msg']


@serving_test
def test_restful_mix_float_str2_input_failed():
    base = start_bool_int_float_restful_server()
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["bool_val"] = [[1.1, "b"]]

    result = post_restful("localhost:5500", base.servable_name, "bool_not", instances)
    assert "json array, elements type is not equal" in result['error_msg']

base.servable_name, "bool_not", instances) assert "string or bytes type only support one item" in result['error_msg'] @serving_test def test_restful_mix_bytes_bool_input_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for instance in instances: instance["bool_val"] = [[{"b64": ""}, {"b64": ""}], [True, False]] result = post_restful("localhost:5500", base.servable_name, "bool_not", instances) assert "string or bytes type only support one item" in result['error_msg'] @serving_test def test_restful_mix_bool_bytes_input_failed(): base = start_bool_int_float_restful_server() # Client instances = [{}, {}, {}] for instance in instances: instance["bool_val"] = [[True, False], [{"b64": ""}, {"b64": ""}]] result = post_restful("localhost:5500", base.servable_name, "bool_not", instances) assert "json array, data should be number, bool, string or bytes" in result['error_msg'] ================================================ FILE: tests/ut/python/tests/test_restful_request.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """test Serving RESTful, with master, worker and client""" import json import requests import numpy as np from common import ServingTestBase, serving_test, generate_cert from common_restful import create_multi_instances_fp32, create_multi_instances_with_batch_fp32 from common_restful import check_number_result, post_restful from mindspore_serving import server @serving_test def test_restful_request_success(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) server.start_restful_server("0.0.0.0:5500") # Client instance_count = 3 instances, y_data_list = create_multi_instances_fp32(instance_count) result = post_restful("localhost:5500", base.servable_name, "add_common", instances) check_number_result(result, y_data_list) @serving_test def test_https_one_way_auth_success(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") generate_cert() server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca="ca.crt", verify_client=False) server.start_restful_server("0.0.0.0:5500", ssl_config=ssl_config) # Client instance_count = 3 instances, y_data_list = create_multi_instances_fp32(instance_count) result = post_restful("0.0.0.0:5500", base.servable_name, "add_common", instances, https=True) check_number_result(result, y_data_list) @serving_test def test_https_mutual_auth_success(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") generate_cert(server_ip="127.0.0.1") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) ssl_config = 
server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca="ca.crt", verify_client=True) server.start_restful_server("0.0.0.0:5500", ssl_config=ssl_config) # Client instance_count = 3 instances, y_data_list = create_multi_instances_fp32(instance_count) result = post_restful("127.0.0.1:5500", base.servable_name, "add_common", instances, https=True) check_number_result(result, y_data_list) @serving_test def test_https_client_auth_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") generate_cert(server_ip="127.0.0.1") ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca="ca.crt", verify_client=False) server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) server.start_grpc_server("127.0.0.1:5500", ssl_config=ssl_config) # Client instance_count = 3 data = create_multi_instances_fp32(instance_count) result = post_restful("127.0.0.1:5500", base.servable_name, "add_common", data[0], verify="client.crt", https=True) print(result) assert "post failed" in result @serving_test def test_https_missing_cert_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") generate_cert(server_ip="127.0.0.1") ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca="ca.crt", verify_client=True) server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) server.start_grpc_server("127.0.0.1:5500", ssl_config=ssl_config) # Client instance_count = 3 data = create_multi_instances_fp32(instance_count) result = post_restful("127.0.0.1:5500", base.servable_name, "add_common", data[0], cert=None, https=True) print(result) assert "post failed" in result @serving_test def test_https_unmatched_cert_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") generate_cert(server_ip="127.0.0.1") ssl_config = server.SSLConfig(certificate="server.crt", private_key="client.key", custom_ca="ca.crt", verify_client=False) server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) try: server.start_restful_server("127.0.0.1:5500", ssl_config=ssl_config) assert False except RuntimeError as e: assert "Serving Error: load private_key from client.key failed" in str(e) @serving_test def test_restful_request_multi_times_success(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) server.start_restful_server("0.0.0.0:5500") for instance_count in range(1, 5): instances, y_data_list = create_multi_instances_fp32(instance_count) result = post_restful("localhost:5500", base.servable_name, "add_common", instances) check_number_result(result, y_data_list) @serving_test def test_restful_request_multi_times_int32_success(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) server.start_restful_server("0.0.0.0:5500") for instance_count in range(1, 5): instances = [] # instance 1 y_data_list = [] for i in range(instance_count): x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.int32) * (i + 1) x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.int32) * (i + 1) y_data_list.append((x1 + x2).astype(np.float32)) instances.append({"x1": x1.tolist(), "x2": x2.tolist()}) result = 
post_restful("localhost:5500", base.servable_name, "add_cast", instances) check_number_result(result, y_data_list) @serving_test def test_restful_request_servable_invalid_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_restful_server("0.0.0.0:5500") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) # Client instance_count = 3 instances, _ = create_multi_instances_fp32(instance_count) result = post_restful("localhost:5500", base.servable_name + "_error", "add_common", instances) assert "servable is not available" in str(result["error_msg"]) @serving_test def test_restful_request_method_invalid_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_restful_server("0.0.0.0:5500") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) # Client instance_count = 3 instances, _ = create_multi_instances_fp32(instance_count) result = post_restful("localhost:5500", base.servable_name, "add_common" + "_error", instances) assert "method is not available" in str(result["error_msg"]) @serving_test def test_restful_request_with_version_number_0_success(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_restful_server("0.0.0.0:5500") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) # Client instance_count = 3 instances, y_data_list = create_multi_instances_fp32(instance_count) result = post_restful("localhost:5500", base.servable_name, "add_common", instances, 0) check_number_result(result, y_data_list) @serving_test def test_restful_request_with_version_number_1_success(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_restful_server("0.0.0.0:5500") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) # Client instance_count = 3 instances, y_data_list = create_multi_instances_fp32(instance_count) result = post_restful("localhost:5500", base.servable_name, "add_common", instances, 1) check_number_result(result, y_data_list) @serving_test def test_restful_request_with_version_number_2_invalid_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_restful_server("0.0.0.0:5500") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) # Client instance_count = 3 instances, _ = create_multi_instances_fp32(instance_count) result = post_restful("localhost:5500", base.servable_name, "add_common", instances, 2) assert "servable is not available" in str(result["error_msg"]) @serving_test def test_restful_request_version_number_negative_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_restful_server("0.0.0.0:5500") server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) # Client instance_count = 3 instances, _ = create_multi_instances_fp32(instance_count) result = post_restful("localhost:5500", base.servable_name, "add_common", instances, -1) assert "please check url, version number range failed" in str(result["error_msg"]) @serving_test def test_restful_request_without_model_invalid_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_restful_server("0.0.0.0:5500") 

@serving_test
def test_restful_request_without_model_invalid_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_restful_server("0.0.0.0:5500")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    # Client
    instance_count = 3
    instances, _ = create_multi_instances_fp32(instance_count)
    instances_map = {"instances": instances}
    post_payload = json.dumps(instances_map)
    print("request:", post_payload)
    request_url = "http://localhost:5500/x/:add_common"
    result = requests.post(request_url, data=post_payload)
    print("result", result.text)
    result = json.loads(result.text)
    assert "please check url, the keyword:[model] must contain" in str(result["error_msg"])


@serving_test
def test_restful_request_without_method_invalid_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_restful_server("0.0.0.0:5500")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    # Client
    instance_count = 3
    instances, _ = create_multi_instances_fp32(instance_count)
    instances_map = {"instances": instances}
    post_payload = json.dumps(instances_map)
    print("request:", post_payload)
    request_url = f"http://localhost:5500/model/{base.servable_name}"
    result = requests.post(request_url, data=post_payload)
    print("result", result.text)
    result = json.loads(result.text)
    assert "please check url, the keyword:[service method] must contain." in str(result["error_msg"])


@serving_test
def test_restful_request_servable_version_reverse_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_restful_server("0.0.0.0:5500")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    # Client
    instance_count = 3
    instances, y_data_list = create_multi_instances_fp32(instance_count)
    instances_map = {"instances": instances}
    post_payload = json.dumps(instances_map)
    print("request:", post_payload)
    request_url = f"http://localhost:5500/version/0/model/{base.servable_name}:add_common"
    result = requests.post(request_url, data=post_payload)
    print("result", result.text)
    result = json.loads(result.text)
    check_number_result(result, y_data_list)


@serving_test
def test_restful_request_preprocess_raise_exception_with_batch_failed():
    base = ServingTestBase()
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True)


def add_trans_datatype(x1, x2):
    raise RuntimeError("invalid preprocess")


@register.register_method(output_names=["y"])
def add_cast(x1, x2):
    x1, x2 = register.add_stage(add_trans_datatype, x1, x2, outputs_count=2, tag="Preprocess")  # cast input to float32
    y = register.add_stage(model, x1, x2, outputs_count=1)
    return y
"""
    base.init_servable_with_servable_config(1, servable_content)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_restful_server("0.0.0.0:5500")
    # Client
    instance_count = 12
    instances, _ = create_multi_instances_with_batch_fp32(instance_count)
    result = post_restful("localhost:5500", base.servable_name, "add_cast", instances)
    print(result)
    assert "Preprocess Failed" in str(result["error_msg"])


@serving_test
def test_restful_request_larger_than_server_receive_max_size():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_restful_server("0.0.0.0:5500", max_msg_mb_size=1)  # 1MB
    # Client
    instances = []
    x1 = np.ones([1024, 1024], np.float32)
    x2 = np.ones([1024, 1024], np.float32)
    instances.append({"x1": x1.tolist(), "x2": x2.tolist()})  # more than 1MB msg

    result = post_restful("localhost:5500", base.servable_name + "_error", "add_common", instances)
    print(result)
    assert "http message is bigger than 1048576" in str(result["error_msg"])


================================================
FILE: tests/ut/python/tests/test_server_client.py
================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test Serving with master, worker and client"""
import os
import time
import signal
import psutil
import numpy as np

from common import ServingTestBase, serving_test, create_client, generate_cert
from common import servable_config_import, servable_config_declare_servable, servable_config_preprocess_cast
from common import servable_config_method_add_common, servable_config_method_add_cast
from common import start_serving_server
from mindspore_serving import server
from mindspore_serving.client import SSLConfig


def create_multi_instances_fp32(instance_count):
    instances = []
    # instance 1
    y_data_list = []
    for i in range(instance_count):
        x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1)
        x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1)
        y_data_list.append(x1 + x2)
        instances.append({"x1": x1, "x2": x2})
    return instances, y_data_list


def check_result(result, y_data_list):
    assert len(result) == len(y_data_list)
    for result_item, y_data in zip(result, y_data_list):
        assert (result_item["y"] == y_data).all()


def is_float_equal(left, right):
    return (np.abs(left - right) < 0.00001).all()


@serving_test
def test_grpc_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    client = create_client("localhost:5500", base.servable_name, "add_common")
    instance_count = 3
    instances, y_data_list = create_multi_instances_fp32(instance_count)
    result = client.infer(instances)
    print(result)
    check_result(result, y_data_list)


@serving_test
def test_grpc_multi_times_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    client = create_client("localhost:5500", base.servable_name, "add_common")
    for instance_count in range(1, 5):
        instances, y_data_list = create_multi_instances_fp32(instance_count)
        result = client.infer(instances)
        check_result(result, y_data_list)

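
# The blocking infer() calls above return the result list directly; the async variants
# below use client.infer_async(), which returns a future whose result() yields the same
# list-of-dict structure as the blocking call.
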
@serving_test
def test_grpc_async_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    client = create_client("localhost:5500", base.servable_name, "add_common")
    instance_count = 3
    instances, y_data_list = create_multi_instances_fp32(instance_count)
    result_future = client.infer_async(instances)
    result = result_future.result()
    print(result)
    check_result(result, y_data_list)


@serving_test
def test_grpc_async_multi_times_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client, use with avoid affecting the next use case
    client = create_client("localhost:5500", base.servable_name, "add_common")
    for instance_count in range(1, 5):
        instances, y_data_list = create_multi_instances_fp32(instance_count)
        result_future = client.infer_async(instances)
        result = result_future.result()
        check_result(result, y_data_list)


@serving_test
def test_grpc_start_grpc_twice_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    try:
        server.start_grpc_server("0.0.0.0:4500")
        assert False
    except RuntimeError as e:
        assert "Serving Error: Serving gRPC server is already running" in str(e)


@serving_test
def test_grpc_start_restful_server_twice_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_restful_server("0.0.0.0:5500")
    try:
        server.start_restful_server("0.0.0.0:4500")
        assert False
    except RuntimeError as e:
        assert "Serving Error: RESTful server is already running" in str(e)


@serving_test
def test_grpc_alone_repeat_grpc_and_restful_port_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_grpc_server("0.0.0.0:7600")
    try:
        server.start_restful_server("0.0.0.0:7600")
        assert False
    except RuntimeError as e:
        assert "Serving Error: RESTful server start failed, bind to the socket address 0.0.0.0:7600 failed" in str(e)


@serving_test
def test_grpc_alone_repeat_grpc_and_restful_port2_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_restful_server("0.0.0.0:7600")
    try:
        server.start_grpc_server("0.0.0.0:7600")
        assert False
    except RuntimeError as e:
        assert "Serving Error: Serving gRPC server start failed, create server failed, address" in str(e)


@serving_test
def test_grpc_servable_content_success():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += servable_config_preprocess_cast
    servable_content += servable_config_method_add_common
    servable_content += servable_config_method_add_cast
    base.init_servable_with_servable_config(1, servable_content)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    instance_count = 3
    instances, y_data_list = create_multi_instances_fp32(instance_count)
    client = create_client("localhost:5500", base.servable_name, "add_common")
    result = client.infer(instances)
    print(result)
    check_result(result, y_data_list)
@serving_test
def test_grpc_one_way_auth_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    generate_cert()
    ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca="ca.crt",
                                  verify_client=False)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500", ssl_config=ssl_config)
    ssl_config = SSLConfig(custom_ca="ca.crt")
    client = create_client("0.0.0.0:5500", base.servable_name, "add_common", ssl_config=ssl_config)
    instance_count = 3
    instances, y_data_list = create_multi_instances_fp32(instance_count)
    result = client.infer(instances)
    print(result)
    check_result(result, y_data_list)


@serving_test
def test_grpc_mutual_auth_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    generate_cert(server_ip="127.0.0.1")
    ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca="ca.crt",
                                  verify_client=True)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("127.0.0.1:5500", ssl_config=ssl_config)
    ssl_config = SSLConfig(certificate="client.crt", private_key="client.key", custom_ca="ca.crt")
    client = create_client("127.0.0.1:5500", base.servable_name, "add_common", ssl_config=ssl_config)
    instance_count = 3
    instances, y_data_list = create_multi_instances_fp32(instance_count)
    result = client.infer(instances)
    print(result)
    check_result(result, y_data_list)


@serving_test
def test_grpc_client_auth_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    generate_cert(server_ip="127.0.0.1")
    ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca="ca.crt",
                                  verify_client=False)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("127.0.0.1:5500", ssl_config=ssl_config)
    ssl_config = SSLConfig(custom_ca="client.crt")
    client = create_client("127.0.0.1:5500", base.servable_name, "add_common", ssl_config=ssl_config)
    instance_count = 3
    data = create_multi_instances_fp32(instance_count)
    result = client.infer(data[0])
    print(result)
    assert "unavailable" in result["error"]


@serving_test
def test_grpc_missing_cert_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    generate_cert(server_ip="127.0.0.1")
    ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.key", custom_ca="ca.crt",
                                  verify_client=True)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("127.0.0.1:5500", ssl_config=ssl_config)
    ssl_config = SSLConfig(custom_ca="ca.crt")
    client = create_client("127.0.0.1:5500", base.servable_name, "add_common", ssl_config=ssl_config)
    instance_count = 3
    data = create_multi_instances_fp32(instance_count)
    result = client.infer(data[0])
    print(result)
    assert "unavailable" in result["error"]


@serving_test
def test_grpc_unmatched_cert_failed():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    generate_cert(server_ip="127.0.0.1")
    ssl_config = server.SSLConfig(certificate="server.crt", private_key="server.crt", custom_ca="ca.crt",
                                  verify_client=True)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    try:
        server.start_grpc_server("127.0.0.1:5500", ssl_config=ssl_config)
        assert False
    except RuntimeError as e:
        assert "Serving Error: Serving gRPC server start failed, create server failed, address" in str(e)
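
# The SSL cases above distinguish one-way and mutual (two-way) authentication. With
# verify_client=False only the server presents a certificate, so the client needs just the
# CA certificate; with verify_client=True the client must also present its own certificate
# and private key, and a missing or mismatched client certificate surfaces as a gRPC
# "unavailable" error rather than a Python exception. A minimal sketch of the two client
# configurations, assuming the files produced by generate_cert():
#     one_way = SSLConfig(custom_ca="ca.crt")
#     mutual = SSLConfig(certificate="client.crt", private_key="client.key", custom_ca="ca.crt")
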
@serving_test
def test_grpc_preprocess_outputs_count_not_match_failed():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += r"""
def add_trans_datatype(x1, x2):
    return x1.astype(np.float32)

@register.register_method(output_names=["y"])
def add_cast(x1, x2):
    x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2)  # cast input to float32
    y = register.call_servable(x1, x2)
    return y
"""
    base.init_servable_with_servable_config(1, servable_content)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    instance_count = 3
    instances, _ = create_multi_instances_fp32(instance_count)
    client = create_client("localhost:5500", base.servable_name, "add_cast")
    result = client.infer(instances)
    print(result)
    assert "Preprocess Failed" in str(result["error"]) or "servable is not available" in str(result["error"])


@serving_test
def test_grpc_postprocess_outputs_count_not_match_failed():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += r"""
def add_trans_datatype(x1, x2):
    return x1.astype(np.float32)

@register.register_method(output_names=["y"])
def add_cast(x1, x2):
    y = register.call_servable(x1, x2)
    y, y2 = register.call_postprocess(add_trans_datatype, y, x2)
    return y
"""
    base.init_servable_with_servable_config(1, servable_content)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    instance_count = 3
    instances, _ = create_multi_instances_fp32(instance_count)
    client = create_client("localhost:5500", base.servable_name, "add_cast")
    result = client.infer(instances)
    print(result)
    assert "Postprocess Failed" in str(result["error"]) or "servable is not available" in str(result["error"])


@serving_test
def test_grpc_preprocess_update_numpy_success():
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += r"""
def preprocess(x3):
    x3[0] = 123
    return x3

def postprocess(x3, x4):
    return x3 + 1, x4 + 2

@register.register_method(output_names=["x3", "x4"])
def add_cast(x1, x2, x3):
    x4 = register.call_preprocess(preprocess, x3)  # [123, 1, 1], expect x3 is x4, same as python function call
    y = register.call_servable(x1, x2)
    x3, x4 = register.call_postprocess(postprocess, x3, x4)
    return x3, x4
"""
    base.init_servable_with_servable_config(1, servable_content)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    instances = [{}, {}, {}]
    for instance in instances:
        instance["x1"] = np.ones([2, 2]).astype(np.float32)
        instance["x2"] = np.ones([2, 2]).astype(np.float32)
        instance["x3"] = np.ones([3]).astype(np.int32)
    # Client, use with avoid affecting the next use case
    client = create_client("localhost:5500", base.servable_name, "add_cast")
    result = client.infer(instances)
    print(result)
    x3 = (np.array([123, 1, 1]) + 1).tolist()
    x4 = (np.array([123, 1, 1]) + 2).tolist()
    assert result[0]["x3"].tolist() == x3
    assert result[0]["x4"].tolist() == x4
    assert result[1]["x3"].tolist() == x3
    assert result[1]["x4"].tolist() == x4
    assert result[2]["x3"].tolist() == x3
    assert result[2]["x4"].tolist() == x4
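
# The next test relies on simple size arithmetic: a 1024 x 1024 float32 tensor occupies
# 1024 * 1024 * 4 = 4194304 bytes (4 MiB), so a request carrying two of them comfortably
# exceeds the 1 MiB limit set via max_msg_mb_size=1 (1048576 bytes, the figure quoted in
# the RESTful variant's error message) and is rejected with gRPC status 8,
# "resource exhausted".
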
@serving_test
def test_grpc_larger_than_server_receive_max_size():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500", max_msg_mb_size=1)  # 1MB
    # Client
    client = create_client("localhost:5500", base.servable_name, "add_common")
    instances = []
    # instance 1
    y_data_list = []
    x1 = np.ones([1024, 1024], np.float32)
    x2 = np.ones([1024, 1024], np.float32)
    y_data_list.append(x1 + x2)
    instances.append({"x1": x1, "x2": x2})
    result = client.infer(instances)  # more than 1MB msg
    print(result)
    assert "Grpc Error, (8, 'resource exhausted')" in str(result["error"])


@serving_test
def test_server_client_input_param_less():
    # fail returned from Worker::RunAsync
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += servable_config_method_add_common
    base.init_servable_with_servable_config(1, servable_content)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    instance_count = 3
    instances = []
    y_data_list = []
    for i in range(instance_count):
        x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
        x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1)
        y_data_list.append(x1 + x2)
        instances.append({"x3": x1, "x2": x2})
    client = create_client("localhost:5500", base.servable_name, "add_common")
    result = client.infer(instances)
    print(result)
    assert "Cannot find input x1 in instance input" in result["error"]


@serving_test
def test_server_client_servable_not_available():
    # fail returned from Worker::RunAsync
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += servable_config_method_add_common
    base.init_servable_with_servable_config(1, servable_content)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    instance_count = 3
    instances = []
    y_data_list = []
    for i in range(instance_count):
        x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
        x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1)
        y_data_list.append(x1 + x2)
        instances.append({"x3": x1, "x2": x2})
    client = create_client("localhost:5500", base.servable_name + "error", "add_common")
    result = client.infer(instances)
    print(result)
    assert "servable is not available" in result["error"]
@serving_test
def test_server_client_max_request_count():
    # fail returned from Worker::RunAsync
    base = ServingTestBase()
    servable_content = servable_config_import
    servable_content += servable_config_declare_servable
    servable_content += r"""
import time

def preprocess(x1, x2):
    time.sleep(1)
    return x1, x2

@register.register_method(output_names=["y"])
def add_common(x1, x2):
    x1, x2 = register.call_preprocess(preprocess, x1, x2)
    y = register.call_servable(x1, x2)
    return y
"""
    base.init_servable_with_servable_config(1, servable_content)
    server.master.context.set_max_enqueued_requests(1)
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32)
    x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32)
    instance = {"x1": x1, "x2": x2}
    client = create_client("localhost:5500", base.servable_name, "add_common")
    result_list = []
    for _ in range(2):
        result = client.infer_async(instance)
        result_list.append(result)
    result0 = result_list[0].result()
    result1 = result_list[1].result()
    print(result0)
    print(result1)
    assert "error" in result0 or "error" in result1
    if "error" in result0:
        assert "error" not in result1
        assert "Serving Error: enqueued requests count exceeds the limit 1" in result0["error"]
    else:
        assert "error" not in result0
        assert "Serving Error: enqueued requests count exceeds the limit 1" in result1["error"]


@serving_test
def test_server_client_one_model_stage_with_batch_dim_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True)

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    x1 = np.array([[3.3, 4.4]], np.float32)
    x2 = np.array([[7.7, 8.8]], np.float32)
    y = x1 + x2
    instances = [{"x1": x1, "x2": x2}] * 3
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()
    assert (result[1]["y"] == y).all()
    assert (result[2]["y"] == y).all()


@serving_test
def test_server_client_one_model_stage_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    y = x1 + x2
    instances = [{"x1": x1, "x2": x2}] * 3
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()
    assert (result[1]["y"] == y).all()
    assert (result[2]["y"] == y).all()


@serving_test
def test_server_client_with_batch_dim_data_size_invalid_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True)

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    x1 = np.array([[3.3, 4.4]], np.float32)
    x2 = np.array([[7.7, 8.8]], np.float32)
    y = x1 + x2
    instances = [{"x1": x1, "x2": x2}, {"x1": x1, "x2": x2}, {"x1": x1, "x2": x2}]
    instances[1]["x2"] = np.array([[7.7, 8.8, 9.9]], np.float32)
    print(instances)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()
    assert "Given model input 1 size 12 not match the size 8 defined in model" in result[1]["error"]
    assert (result[2]["y"] == y).all()
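
# Note on the size figures asserted above: with_batch_dim=True treats dimension 0 as the
# batch dimension, so each instance is checked against the model's per-instance shape,
# here [1, 2], i.e. 2 float32 values = 8 bytes. The deliberately malformed input
# [[7.7, 8.8, 9.9]] carries 3 float32 values = 12 bytes, which produces the
# "size 12 not match the size 8" error for that instance only, while the other instances
# in the same batch still succeed.
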
@serving_test
def test_server_client_with_batch_dim_data_type_invalid_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True)

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    x1 = np.array([[3.3, 4.4]], np.float32)
    x2 = np.array([[7.7, 8.8]], np.float32)
    y = x1 + x2
    instances = [{"x1": x1, "x2": x2}, {"x1": x1, "x2": x2}, {"x1": x1, "x2": x2}]
    instances[1]["x2"] = np.array([[7.7, 9.9]], np.int32)
    print(instances)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()
    assert "Given model input 1 data type kMSI_Int32 not match the data type kMSI_Float32 defined in model" in \
           result[1]["error"]
    assert (result[2]["y"] == y).all()


@serving_test
def test_server_client_data_size_invalid_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    y = x1 + x2
    instances = [{"x1": x1, "x2": x2}, {"x1": x1, "x2": x2}, {"x1": x1, "x2": x2}]
    instances[1]["x2"] = np.array([[5.5, 6.6, 8.8], [7.7, 8.8, 9.9]], np.float32)
    print(instances)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()
    assert "Given model input 1 size 24 not match the size 16 defined in model" in result[1]["error"]
    assert (result[2]["y"] == y).all()


@serving_test
def test_server_client_data_type_invalid_failed():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

@register.register_method(output_names="y")
def predict(x1, x2):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32)
    x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32)
    y = x1 + x2
    instances = [{"x1": x1, "x2": x2}, {"x1": x1, "x2": x2}, {"x1": x1, "x2": x2}]
    instances[1]["x2"] = np.array([[5.5, 6.8], [7.7, 9.9]], np.int32)
    print(instances)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert (result[0]["y"] == y).all()
    assert "Given model input 1 data type kMSI_Int32 not match the data type kMSI_Float32 defined in model" in \
           result[1]["error"]
    assert (result[2]["y"] == y).all()
@serving_test
def test_server_client_two_model_stage_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_add, y, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 + x3
        instances.append({"x1": x1, "x2": x2, "x3": x3})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_server_client_one_model_stage_with_function_front_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def add_test(x1, x2):
    return x1 + x2 + 1

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(add_test, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_add, y, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 + x3 + 1
        instances.append({"x1": x1, "x2": x2, "x3": x3})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_server_client_one_model_stage_with_function_tail_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def add_test(x1, x2):
    return x1 + x2 + 1

@register.register_method(output_names="y")
def predict(x1, x2, x3):
    y = register.add_stage(tensor_add, x1, x2, outputs_count=1)
    y = register.add_stage(add_test, y, x3, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 + x3 + 1
        instances.append({"x1": x1, "x2": x2, "x3": x3})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])
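
# In these mixed pipelines, add_stage() accepts either a declared model or a plain Python
# function as a stage, and stages run in registration order for each instance. The expected
# outputs follow directly from that: every add_test stage contributes an extra +1 on top of
# the element-wise sum, so one function stage gives y = x1 + x2 + x3 + 1 above, and the
# tests that follow add +2, +3 or +4 for two, three or four function stages respectively.
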
@serving_test
def test_server_client_one_model_stage_with_function_front_and_tail_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def add_test(x1, x2):
    return x1 + x2 + 1

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4):
    y = register.add_stage(add_test, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_add, y, x3, outputs_count=1)
    y = register.add_stage(add_test, y, x4, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[3.5, 4.3], [5.2, 6.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 + x3 + x4 + 2
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_server_client_one_model_stage_with_function_front_and_tail_double_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def add_test(x1, x2):
    return x1 + x2 + 1

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5, x6):
    y = register.add_stage(add_test, x1, x2, outputs_count=1)
    y = register.add_stage(add_test, y, x3, outputs_count=1)
    y = register.add_stage(tensor_add, y, x4, outputs_count=1)
    y = register.add_stage(add_test, y, x5, outputs_count=1)
    y = register.add_stage(add_test, y, x6, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[3.5, 4.3], [5.2, 6.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[1.5, 2.3], [3.2, 4.4]], np.float32) * 1.1 * (i + 1)
        x6 = np.array([[5.5, 6.3], [7.2, 8.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 + x3 + x4 + x5 + x6 + 4
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5, "x6": x6})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])
@serving_test
def test_server_client_two_model_stage_with_function_front_and_tail_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False)

def add_test(x1, x2):
    return x1 + x2 + 1

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5, x6):
    y = register.add_stage(add_test, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_add, y, x3, outputs_count=1)
    y = register.add_stage(add_test, y, x4, outputs_count=1)
    y = register.add_stage(tensor_add, y, x5, outputs_count=1)
    y = register.add_stage(add_test, y, x6, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[8.5, 7.3], [6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[3.5, 4.3], [5.2, 6.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[1.5, 2.3], [3.2, 4.4]], np.float32) * 1.1 * (i + 1)
        x6 = np.array([[5.5, 6.3], [7.2, 8.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 + x3 + x4 + x5 + x6 + 3
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5, "x6": x6})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])


@serving_test
def test_server_client_two_model_stage_with_function_front_and_tail_with_batch_dim_success():
    servable_content = r"""
import numpy as np
from mindspore_serving.server import register

tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True)

def add_test(x1, x2):
    return x1 + x2 + 1

@register.register_method(output_names="y")
def predict(x1, x2, x3, x4, x5, x6):
    y = register.add_stage(add_test, x1, x2, outputs_count=1)
    y = register.add_stage(tensor_add, y, x3, outputs_count=1)
    y = register.add_stage(add_test, y, x4, outputs_count=1)
    y = register.add_stage(tensor_add, y, x5, outputs_count=1)
    y = register.add_stage(add_test, y, x6, outputs_count=1)
    return y
"""
    base = start_serving_server(servable_content, model_file="tensor_add.mindir")
    # Client
    instances = []
    ys = []
    for i in range(3):
        x1 = np.array([[3.3, 4.4]], np.float32) * 1.1 * (i + 1)
        x2 = np.array([[7.7, 8.8]], np.float32) * 1.1 * (i + 1)
        x3 = np.array([[6.2, 5.4]], np.float32) * 1.1 * (i + 1)
        x4 = np.array([[5.2, 6.4]], np.float32) * 1.1 * (i + 1)
        x5 = np.array([[3.2, 4.4]], np.float32) * 1.1 * (i + 1)
        x6 = np.array([[7.2, 8.4]], np.float32) * 1.1 * (i + 1)
        y = x1 + x2 + x3 + x4 + x5 + x6 + 3
        instances.append({"x1": x1, "x2": x2, "x3": x3, "x4": x4, "x5": x5, "x6": x6})
        ys.append(y)
    client = create_client("localhost:5500", base.servable_name, "predict")
    result = client.infer(instances)
    print("result", result)
    assert is_float_equal(result[0]["y"], ys[0])
    assert is_float_equal(result[1]["y"], ys[1])
    assert is_float_equal(result[2]["y"], ys[2])
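
# The three tests below probe worker-process supervision by signalling the serving worker
# child processes directly. Read together, their expectations suggest: SIGINT shuts the
# worker down for good (later calls fail with gRPC 14, 'unavailable'); SIGKILL after the
# worker has served traffic is recovered from, with requests succeeding again once the
# ~3 second sleep gives the restart time to complete; SIGKILL immediately after startup is
# not recovered from. The exact restart policy lives in the master/worker C++ code, so this
# summary is inferred from the assertions rather than a specification.
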
@serving_test
def test_server_client_worker_exit_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    client = create_client("localhost:5500", base.servable_name, "add_common")
    instance_count = 3
    instances, y_data_list = create_multi_instances_fp32(instance_count)
    result = client.infer(instances)
    print(result)
    check_result(result, y_data_list)
    cur_process = psutil.Process(os.getpid())
    children = cur_process.children(recursive=False)
    for item in children:
        os.kill(item.pid, signal.SIGINT)
    time.sleep(2)
    result = client.infer(instances)
    print(result)
    assert "Grpc Error, (14, 'unavailable')" in result["error"]


@serving_test
def test_server_client_worker_kill_restart_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    # Client
    client = create_client("localhost:5500", base.servable_name, "add_common")
    instance_count = 3
    instances, y_data_list = create_multi_instances_fp32(instance_count)
    result = client.infer(instances)
    print(result)
    check_result(result, y_data_list)
    cur_process = psutil.Process(os.getpid())
    children = cur_process.children(recursive=False)
    for item in children:
        os.kill(item.pid, signal.SIGKILL)
    time.sleep(3)
    result = client.infer(instances)
    print(result)
    check_result(result, y_data_list)


@serving_test
def test_server_client_worker_kill_no_restart_success():
    base = ServingTestBase()
    base.init_servable(1, "add_servable_config.py")
    server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0))
    server.start_grpc_server("0.0.0.0:5500")
    cur_process = psutil.Process(os.getpid())
    children = cur_process.children(recursive=False)
    for item in children:
        os.kill(item.pid, signal.SIGKILL)
    time.sleep(3)
    # Client
    client = create_client("localhost:5500", base.servable_name, "add_common")
    instance_count = 3
    instances, _ = create_multi_instances_fp32(instance_count)
    result = client.infer(instances)
    print(result)
    assert "Grpc Error, (14, 'unavailable')" in result["error"]


@serving_test
def test_start_server_invalid_grpc_address_failed():
    try:
        server.start_grpc_server("invalid address")
        assert False
    except RuntimeError as e:
        assert "The format of the Serving gRPC address 'invalid address' is illegal" in str(e)


@serving_test
def test_start_server_invalid_grpc_address2_failed():
    try:
        server.start_grpc_server("127.0.0.1")
        assert False
    except RuntimeError as e:
        assert "The format of the Serving gRPC address '127.0.0.1' is illegal" in str(e)


@serving_test
def test_start_server_invalid_grpc_address3_failed():
    try:
        server.start_grpc_server("127.0.0.0.1:5000")
        assert False
    except RuntimeError as e:
        assert "Serving gRPC server start failed, create server failed, address 127.0.0.0.1:5000" in str(e)


@serving_test
def test_start_server_invalid_grpc_address4_failed():
    try:
        server.start_grpc_server("127.0.0.1:5000000")
        assert False
    except RuntimeError as e:
        assert "The port of the Serving gRPC address '127.0.0.1:5000000' is out of legal range [1 ~ 65535]" in str(e)


@serving_test
def test_start_server_invalid_grpc_address5_failed():
    try:
        server.start_grpc_server("unix:")
        assert False
    except RuntimeError as e:
        assert "Empty grpc server unix domain socket address" in str(e)


@serving_test
def test_start_server_invalid_grpc_address6_failed():
    try:
        server.start_grpc_server("127.0.256.1:5000")
        assert False
    except RuntimeError as e:
        assert "Serving gRPC server start failed, create server failed, address 127.0.256.1:5000" in str(e)


@serving_test
def test_start_server_invalid_grpc_address7_failed():
    try:
        server.start_grpc_server("127.0.0.1:5000:5000")
        assert False
    except RuntimeError as e:
        assert "Serving gRPC server start failed, create server failed, address 127.0.0.1:5000:5000" in str(e)


@serving_test
def test_start_server_invalid_restful_address_failed():
    try:
        server.start_restful_server("invalid address")
        assert False
    except RuntimeError as e:
        assert "The format of the RESTful server address 'invalid address' is illegal" in str(e)


@serving_test
def test_start_server_invalid_restful_address2_failed():
    try:
        server.start_restful_server("127.0.0.1")
        assert False
    except RuntimeError as e:
        assert "The format of the RESTful server address '127.0.0.1' is illegal" in str(e)


@serving_test
def test_start_server_invalid_restful_address3_failed():
    try:
        server.start_restful_server("127.0.0.0.1:5000")
        assert False
    except RuntimeError as e:
        assert "RESTful server start failed, bind to the socket address 127.0.0.0.1:5000 failed" in str(e)
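
# A pattern worth noting in these address tests: syntactically illegal addresses and
# out-of-range ports ("is illegal", "out of legal range [1 ~ 65535]") are caught by
# Serving's own validation before any socket is touched, while strings that merely look
# like addresses (extra octets, octet values above 255, doubled ports) fall through to the
# transport layer and come back as "create server failed" / "bind ... failed" errors.
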
@serving_test
def test_start_server_invalid_restful_address4_failed():
    try:
        server.start_restful_server("127.0.0.1:5000000")
        assert False
    except RuntimeError as e:
        assert "The port of the RESTful server address '127.0.0.1:5000000' is out of legal range [1 ~ 65535]" in str(e)


@serving_test
def test_start_server_invalid_restful_address5_failed():
    try:
        server.start_restful_server("127.0.256.1:5000")
        assert False
    except RuntimeError as e:
        assert "RESTful server start failed, bind to the socket address 127.0.256.1:5000 failed" in str(e)


@serving_test
def test_start_server_invalid_restful_address6_failed():
    try:
        server.start_restful_server("unix:address_temp")
        assert False
    except RuntimeError as e:
        assert "RESTful server does not support binding to unix domain socket" in str(e)


================================================
FILE: tests/ut/python/tests/test_serving_log.py
================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import sys
import subprocess

from common import serving_test


def start_new_log_process(log_py_context, env_set):
    """start new process with log"""
    with open("test_log.py", "w") as fp:
        fp.write(log_py_context)
    log_file = os.path.join(os.getcwd(), "test_log.py")
    log_text = os.path.join(os.getcwd(), "test_log.txt")
    print(f"\npython {log_file} >& {log_text}")
    arg = f"{sys.executable} {log_file}"
    args = arg.split()
    new_env = os.environ.copy()
    new_env.update(env_set)
    with open(log_text, "w") as fp:
        sub = subprocess.Popen(args=args, shell=False, stdout=fp, stderr=fp, env=new_env)
        sub.wait()
    with open(log_text, "r") as fp:
        lines = fp.read()
    find_info = (lines.find("[INFO]") != -1)
    find_warning = (lines.find("[WARNING]") != -1)
    find_error = (lines.find("[ERROR]") != -1)
    print("log_text:------------------")
    print(lines)
    print("log_text end------------------")
    os.system(f"rm -f {log_file} {log_text}")
    return find_info, find_warning, find_error


def start_new_log_process_py(env_set):
    """start new process with python log"""
    log_py_context = r"""
from mindspore_serving import log as logger
from mindspore_serving import server

def log_process():
    logger.info("info msg test")
    logger.warning("warning msg test")
    logger.error("error msg test")
    logger.debug("debug msg test")

log_process()
"""
    return start_new_log_process(log_py_context, env_set)


def start_new_log_process_cpp(env_set):
    """start new process with cpp log"""
    log_py_context = r"""
from mindspore_serving import log as logger
from mindspore_serving import server

def log_process():
    # info
    server.start_grpc_server("0.0.0.0:5500")
    try:
        # error
        server.start_grpc_server("0.0.0.0:5500")
    except RuntimeError:
        pass

log_process()
"""
    return start_new_log_process(log_py_context, env_set)


@serving_test
def test_log_level_python_debug():
    find_info, find_warning, find_error = start_new_log_process_py({"GLOG_v": "0"})
    assert find_info
    assert find_warning
    assert find_error


@serving_test
def test_log_level_python_info():
    find_info, find_warning, find_error = start_new_log_process_py({"GLOG_v": "1"})
    assert find_info
    assert find_warning
    assert find_error
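
# GLOG_v maps to the usual glog-style numeric levels, which these assertions pin down:
# 0 = DEBUG, 1 = INFO, 2 = WARNING, 3 = ERROR. A level only suppresses messages below it,
# so for example GLOG_v=2 drops [INFO] but keeps [WARNING] and [ERROR] in the captured
# subprocess output.
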
@serving_test
def test_log_level_python_warning():
    find_info, find_warning, find_error = start_new_log_process_py({"GLOG_v": "2"})
    assert not find_info
    assert find_warning
    assert find_error


@serving_test
def test_log_level_python_error():
    find_info, find_warning, find_error = start_new_log_process_py({"GLOG_v": "3"})
    assert not find_info
    assert not find_warning
    assert find_error


@serving_test
def test_log_level_cpp_debug():
    find_info, _, find_error = start_new_log_process_cpp({"GLOG_v": "0"})
    assert find_info
    assert find_error


@serving_test
def test_log_level_cpp_info():
    find_info, _, find_error = start_new_log_process_cpp({"GLOG_v": "1"})
    assert find_info
    assert find_error


@serving_test
def test_log_level_cpp_warning():
    find_info, _, find_error = start_new_log_process_cpp({"GLOG_v": "2"})
    assert not find_info
    assert find_error


@serving_test
def test_log_level_cpp_error():
    find_info, _, find_error = start_new_log_process_cpp({"GLOG_v": "3"})
    assert not find_info
    assert find_error


@serving_test
def test_log_level_cpp_debug2():
    find_info, _, find_error = start_new_log_process_cpp({"GLOG_v": "3", "MS_SUBMODULE_LOG_v": "{SERVING:0}"})
    assert find_info
    assert find_error


@serving_test
def test_log_level_cpp_info2():
    find_info, _, find_error = start_new_log_process_cpp({"GLOG_v": "3", "MS_SUBMODULE_LOG_v": "{SERVING:1}"})
    assert find_info
    assert find_error


@serving_test
def test_log_level_cpp_warning2():
    find_info, _, find_error = start_new_log_process_cpp({"GLOG_v": "3", "MS_SUBMODULE_LOG_v": "{SERVING:2}"})
    assert not find_info
    assert find_error


@serving_test
def test_log_level_cpp_error2():
    find_info, _, find_error = start_new_log_process_cpp({"GLOG_v": "3", "MS_SUBMODULE_LOG_v": "{SERVING:3}"})
    assert not find_info
    assert find_error
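
# The *2 variants above show that MS_SUBMODULE_LOG_v can override the global GLOG_v per
# C++ submodule: with GLOG_v=3 (errors only), setting "{SERVING:0}" or "{SERVING:1}"
# re-enables [INFO] output from the SERVING module, while "{SERVING:2}" and "{SERVING:3}"
# keep it suppressed. The Python-level tests only vary GLOG_v, so no submodule override is
# exercised on that path.
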
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) def test_concat(x1, x2): return x1 + x2 @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(test_concat, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] ys = [] x1s = [] x2s = [] x1s.append(np.array([[101.1, 205.2], [41.3, 62.4]], np.float32)) x2s.append(np.array([[3.5, 5.6], [7.7, 9.8]], np.float32)) x1s.append(np.array([[41.3, 32.2], [4.1, 3.9]], np.float32)) x2s.append(np.array([[1.4, 4.5], [9.6, 19.7]], np.float32)) x1s.append(np.array([[11.1, 21.2], [41.9, 61.8]], np.float32)) x2s.append(np.array([[31.5, 51.7], [71.4, 91.3]], np.float32)) for i in range(3): instances.append({"x1": x1s[i], "x2": x2s[i]}) y = x1s[i] + x2s[i] ys.append(y) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y"] == ys[0]).all() assert (result[1]["y"] == ys[1]).all() assert (result[2]["y"] == ys[2]).all() @serving_test def test_stage_function_one_function_stage_two_output_success(): """ Feature: test servable_config.py stage Description: Test stage with one input, two outputs Expectation: Serving server work ok. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) def test_concat(x1): return x1 + 1, x1-1 @register.register_method(output_names=["y1", "y2"]) def predict(x1): y1, y2 = register.add_stage(test_concat, x1, outputs_count=2) return y1, y2 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] y1s = [] y2s = [] x1s = [] x1s.append(np.array([[101.1, 205.2], [41.3, 62.4]], np.float32)) x1s.append(np.array([[41.3, 32.2], [4.1, 3.9]], np.float32)) x1s.append(np.array([[11.1, 21.2], [41.9, 61.8]], np.float32)) for i in range(3): instances.append({"x1": x1s[i]}) y1s.append(x1s[i] + 1) y2s.append(x1s[i] - 1) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert (result[0]["y1"] == y1s[0]).all() assert (result[1]["y1"] == y1s[1]).all() assert (result[2]["y1"] == y1s[2]).all() assert (result[0]["y2"] == y2s[0]).all() assert (result[1]["y2"] == y2s[1]).all() assert (result[2]["y2"] == y2s[2]).all() @serving_test def test_stage_function_one_function_stage_output_more_failed(): """ Feature: test servable_config.py stage Description: Test stage declared outputs_count < python function outputs count Expectation: Serving server report error. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1, x2): return x1+x2, x1-x2, 1 @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=2) return y1, y2 """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) instances = [{"x1": x1, "x2": x2}] * 3 client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_output_less_failed(): """ Feature: test servable_config.py stage Description: Test stage declared outputs_count > python function outputs count Expectation: Serving server report error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1, x2): return x1+x2 @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=2) return y1, y2 """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) instances = [{"x1": x1, "x2": x2}] * 3 client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_error_outputs_count_failed(): """ Feature: test servable_config.py stage Description: Test stage declared outputs_count > python function outputs count Expectation: Serving server report error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1, x2): return x1+x2 @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=3) return y1, y2 """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "too many values to unpack (expected 2)" in str(e) @serving_test def test_stage_function_one_function_stage_error_outputs_count2_failed(): """ Feature: test servable_config.py stage Description: Test stage declared outputs_count < python function outputs count Expectation: Serving server report error. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1, x2): return x1+x2 @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=1) return y1, y2 """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "cannot unpack non-iterable _TensorDef object" in str(e) @serving_test def test_stage_function_one_function_stage_input_more_failed(): """ Feature: test servable_config.py stage Description: Test stage declared inputs count < python function inputs count Expectation: Serving server startup error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1, x2, x3): return x1, x2 @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=2) return y1, y2 """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "function func_test input args count 3 not match the count 2 registered in method" in str(e) @serving_test def test_stage_function_one_function_stage_input_less_failed(): """ Feature: test servable_config.py stage Description: Test stage declared inputs count > python function inputs count Expectation: Serving server startup error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1): return x1, x2 @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=2) return y1, y2 """ try: start_serving_server(servable_content) assert False except RuntimeError as e: assert "function func_test input args count 1 not match the count 2 registered in method" in str(e) @serving_test def test_stage_function_one_function_stage_raise_exception_failed(): """ Feature: test servable_config.py stage Description: Stage python function raise exception Expectation: Serving server report error. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1, x2): raise RuntimeError("runtime error text") @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=2) return y1, y2 """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) instances = [{"x1": x1, "x2": x2}] * 3 client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_none_outputs_failed(): """ Feature: test servable_config.py stage Description: Stage python function return None Expectation: Serving server report error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1, x2): print("none outputs") @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=2) return y1, y2 """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) instances = [{"x1": x1, "x2": x2}] * 3 client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_invalid_output_dtype_failed(): """ Feature: test servable_config.py stage Description: Stage python function return invalid data, dtype is not supported Expectation: Serving server report error. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test(x1, x2): return x1.dtype, x2.dtype @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test, x1, x2, outputs_count=2) return y1, y2 """ base = start_serving_server(servable_content) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) instances = [{"x1": x1, "x2": x2}] * 3 client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_batch_size_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, and result output count is 1, tuple/list Expectation: Serving server work ok. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y = instance[0] + instance[1] results.append([y]) return results @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] ys = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y = x1 + x2 instances.append({"x1": x1, "x2": x2}) ys.append(y) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y"], ys[0]) assert is_float_equal(result[1]["y"], ys[1]) assert is_float_equal(result[2]["y"], ys[2]) @serving_test def test_stage_function_one_function_stage_batch_size2_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, and result output count is 1, not tuple/list Expectation: Serving server work ok. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y = instance[0] + instance[1] results.append(y) return results @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] ys = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y = x1 + x2 instances.append({"x1": x1, "x2": x2}) ys.append(y) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y"], ys[0]) assert is_float_equal(result[1]["y"], ys[1]) assert is_float_equal(result[2]["y"], ys[2]) @serving_test def test_stage_function_one_function_stage_batch_size3_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, use yield, not tuple/list Expectation: Serving server work ok. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y = instance[0] + instance[1] yield y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] ys = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y = x1 + x2 instances.append({"x1": x1, "x2": x2}) ys.append(y) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y"], ys[0]) assert is_float_equal(result[1]["y"], ys[1]) assert is_float_equal(result[2]["y"], ys[2]) @serving_test def test_stage_function_one_function_stage_batch_size4_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, use yield, use tuple/list Expectation: Serving server work ok. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y = instance[0] + instance[1] yield [y] @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] ys = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y = x1 + x2 instances.append({"x1": x1, "x2": x2}) ys.append(y) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y"], ys[0]) assert is_float_equal(result[1]["y"], ys[1]) assert is_float_equal(result[2]["y"], ys[2]) @serving_test def test_stage_function_one_function_stage_batch_size_equal1_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, batch size = 1 Expectation: Serving server work ok. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y = instance[0] + instance[1] yield y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=1) return y """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] ys = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y = x1 + x2 instances.append({"x1": x1, "x2": x2}) ys.append(y) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y"], ys[0]) assert is_float_equal(result[1]["y"], ys[1]) assert is_float_equal(result[2]["y"], ys[2]) @serving_test def test_stage_function_one_function_stage_batch_size_0_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, batch size=0, batch size is determined by system Expectation: Serving server work ok. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y = instance[0] + instance[1] yield y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=0) return y """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] ys = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y = x1 + x2 instances.append({"x1": x1, "x2": x2}) ys.append(y) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y"], ys[0]) assert is_float_equal(result[1]["y"], ys[1]) assert is_float_equal(result[2]["y"], ys[2]) @serving_test def test_stage_function_one_function_stage_error_batch_size_failed(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, batch size is invalid Expectation: Serving server startup failed. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y = instance[0] + instance[1] yield y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=-1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir") assert False except RuntimeError as e: assert "Parameter 'batch_size' should be >= 0" in str(e) @serving_test def test_stage_function_one_function_stage_batch_size_two_outputs_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, yield, result outputs count is 2 Expectation: Serving server work well. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] yield y1, y2 @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test_batch, x1, x2, outputs_count=2, batch_size=2) return y1, y2 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] y1s = [] y2s = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y1 = x1 + x2 y2 = x1 - x2 instances.append({"x1": x1, "x2": x2}) y1s.append(y1) y2s.append(y2) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y1"], y1s[0]) assert is_float_equal(result[1]["y1"], y1s[1]) assert is_float_equal(result[2]["y1"], y1s[2]) assert is_float_equal(result[0]["y2"], y2s[0]) assert is_float_equal(result[1]["y2"], y2s[1]) assert is_float_equal(result[2]["y2"], y2s[2]) @serving_test def test_stage_function_one_function_stage_batch_size_two_outputs_multi_times_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, multi stage Expectation: Serving server work well. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] yield y1, y2 @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test_batch, x1, x2, outputs_count=2, batch_size=2) y1, y2 = register.add_stage(func_test_batch, y1, y2, outputs_count=2, batch_size=2) return y1, y2 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] y1s = [] y2s = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y1, y2 = x1 + x2, x1 - x2 y1, y2 = y1 + y2, y1 - y2 instances.append({"x1": x1, "x2": x2}) y1s.append(y1) y2s.append(y2) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y1"], y1s[0]) assert is_float_equal(result[1]["y1"], y1s[1]) assert is_float_equal(result[2]["y1"], y1s[2]) assert is_float_equal(result[0]["y2"], y2s[0]) assert is_float_equal(result[1]["y2"], y2s[1]) assert is_float_equal(result[2]["y2"], y2s[2]) @serving_test def test_stage_function_one_function_stage_batch_size_two_outputs2_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, result output count is 2 Expectation: Serving server work well. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2 = register.add_stage(func_test_batch, x1, x2, outputs_count=2, batch_size=2) return y1, y2 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] y1s = [] y2s = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) y1 = x1 + x2 y2 = x1 - x2 instances.append({"x1": x1, "x2": x2}) y1s.append(y1) y2s.append(y2) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y1"], y1s[0]) assert is_float_equal(result[1]["y1"], y1s[1]) assert is_float_equal(result[2]["y1"], y1s[2]) assert is_float_equal(result[0]["y2"], y2s[0]) assert is_float_equal(result[1]["y2"], y2s[1]) assert is_float_equal(result[2]["y2"], y2s[2]) @serving_test def test_stage_function_one_function_stage_batch_size_input_more_success(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, used inputs count 2 < declared inputs count 3 Expectation: Serving server work well. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2, x3): y1, y2 = register.add_stage(func_test_batch, x1, x2, x3, outputs_count=2, batch_size=2) return y1, y2 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] y1s = [] y2s = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) x3 = np.array([[1.5, 2.6], [3.7, 4.8]], np.float32) * 1.1 * (i + 1) y1 = x1 + x2 y2 = x1 - x2 instances.append({"x1": x1, "x2": x2, "x3": x3}) y1s.append(y1) y2s.append(y2) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) assert is_float_equal(result[0]["y1"], y1s[0]) assert is_float_equal(result[1]["y1"], y1s[1]) assert is_float_equal(result[2]["y1"], y1s[2]) assert is_float_equal(result[0]["y2"], y2s[0]) assert is_float_equal(result[1]["y2"], y2s[1]) assert is_float_equal(result[2]["y2"], y2s[2]) @serving_test def test_stage_function_one_function_stage_batch_size_input_less_failed(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, used inputs count 2 > declared inputs count 1 Expectation: Serving server report error. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y1", "y2"]) def predict(x1): y1, y2 = register.add_stage(func_test_batch, x1, outputs_count=2, batch_size=2) return y1, y2 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] for i in range(3): x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) * 1.1 * (i + 1) instances.append({"x1": x1}) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_batch_size_output_more_failed(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, outputs count 2 < declared outputs_count 3 Expectation: Serving server report error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y1", "y2"]) def predict(x1, x2): y1, y2, y3 = register.add_stage(func_test_batch, x1, x2, outputs_count=3, batch_size=2) return y1, y2 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] for i in range(3): x1 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) x2 = np.array([[1.5, 2.6], [3.7, 4.8]], np.float32) * 1.1 * (i + 1) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_batch_size_output_less_failed(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, outputs count 2 > declared outputs_count 1 Expectation: Serving server report error. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1, y2]) return results @register.register_method(output_names=["y1"]) def predict(x1, x2): y1 = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y1 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] for i in range(3): x1 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) x2 = np.array([[1.5, 2.6], [3.7, 4.8]], np.float32) * 1.1 * (i + 1) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_batch_size_output_less2_failed(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, outputs count 2 > declared outputs_count 1, yield Expectation: Serving server report error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] yield y1, y2 @register.register_method(output_names=["y1"]) def predict(x1, x2): y1 = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y1 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] for i in range(3): x1 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) x2 = np.array([[1.5, 2.6], [3.7, 4.8]], np.float32) * 1.1 * (i + 1) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_batch_size_raise_exception_failed(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, raise exception Expectation: Serving server report error. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): raise RuntimeError("runtime error test") @register.register_method(output_names=["y1"]) def predict(x1, x2): y1 = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y1 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] for i in range(3): x1 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) x2 = np.array([[1.5, 2.6], [3.7, 4.8]], np.float32) * 1.1 * (i + 1) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_batch_size_none_return_failed(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, return None Expectation: Serving server report error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): pass @register.register_method(output_names=["y1"]) def predict(x1, x2): y1 = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y1 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] for i in range(3): x1 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) x2 = np.array([[1.5, 2.6], [3.7, 4.8]], np.float32) * 1.1 * (i + 1) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result[0]["error"] @serving_test def test_stage_function_one_function_stage_batch_size_invalid_output_dtype_failed(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, return invalid data Expectation: Serving server report error. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def func_test_batch(instances): results = [] for instance in instances: y1 = instance[0] + instance[1] y2 = instance[0] - instance[1] results.append([y1.dtype, y2.dtype]) return results @register.register_method(output_names=["y1"]) def predict(x1, x2): y1 = register.add_stage(func_test_batch, x1, x2, outputs_count=1, batch_size=2) return y1 """ base = start_serving_server(servable_content, model_file="tensor_add.mindir") # Client instances = [] for i in range(3): x1 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) * 1.1 * (i + 1) x2 = np.array([[1.5, 2.6], [3.7, 4.8]], np.float32) * 1.1 * (i + 1) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "predict") result = client.infer(instances) print("result", result) if isinstance(result, dict): assert "servable is not available" in result["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result["error"] else: assert "servable is not available" in result[0]["error"] \ or f"Call Function '{base.servable_name}.func_test_batch' Failed" in result[0]["error"] @serving_test def test_servable_postprocess_result_count_less(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, return instances count less then input instances count Expectation: Serving server report error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) def postprocess(instances): count = len(instances) for i in range(count -1): yield i @register.register_method(output_names=["y"]) def add_common(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) y = register.add_stage(postprocess, y, outputs_count=1, batch_size=4, tag="Postprocess") return y """ base = start_serving_server(servable_content) # Client instance_count = 2 instances = [] y_data_list = [] for i in range(instance_count): x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) y_data_list.append(x1 + x2) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "add_common") result = client.infer(instances) print(result) assert "Postprocess Failed" in str(result[1]["error"]) or 'servable is not available' in str(result[1]["error"]) @serving_test def test_servable_postprocess_result_count_more(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, return instances count more then input instances count Expectation: Serving server work well. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) def postprocess(instances): count = len(instances) for i in range(count + 1): yield i @register.register_method(output_names=["y"]) def add_common(x1, x2): y = register.add_stage(model, x1, x2, outputs_count=1) y = register.add_stage(postprocess, y, outputs_count=1, batch_size=4, tag="Postprocess") return y """ base = start_serving_server(servable_content) # Client instance_count = 2 instances = [] y_data_list = [] for i in range(instance_count): x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) y_data_list.append(x1 + x2) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "add_common") result = client.infer(instances) print(result) assert len(result) == instance_count assert result[0]["y"] == 0 assert result[1]["y"] == 1 @serving_test def test_stage_function_preprocess_result_count_less(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, return instances count less then input instances count Expectation: Serving server report error. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) def preprocess(instances): count = len(instances) for i in range(count-1): yield i @register.register_method(output_names=["y"]) def add_common(x1, x2): x3 = register.add_stage(preprocess, x1, outputs_count=1, batch_size=4, tag="Preprocess") y = register.add_stage(model, x1, x2, outputs_count=1) return x3 """ base = start_serving_server(servable_content) # Client instance_count = 2 instances = [] y_data_list = [] for i in range(instance_count): x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) y_data_list.append(x1 + x2) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "add_common") result = client.infer(instances) print(result) if isinstance(result, list): assert "Preprocess Failed" in str(result[1]["error"]) or "servable is not available" in str(result[1]["error"]) else: assert "Preprocess Failed" in str(result["error"]) or "servable is not available" in str(result["error"]) @serving_test def test_stage_function_preprocess_result_count_more(): """ Feature: test servable_config.py stage Description: Stage python function run with batch_size parameter, return instances count more then input instances count Expectation: Serving server work well. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) def preprocess(instances): count = len(instances) for i in range(count+1): yield i @register.register_method(output_names=["y"]) def add_common(x1, x2): x3 = register.add_stage(preprocess, x1, outputs_count=1, batch_size=4, tag="Preprocess") y = register.add_stage(model, x1, x2, outputs_count=1) return x3 """ base = start_serving_server(servable_content) # Client instance_count = 3 instances = [] y_data_list = [] for i in range(instance_count): x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) y_data_list.append(x1 + x2) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "add_common") result = client.infer(instances) print(result) assert len(result) == instance_count @serving_test def test_stage_function_push_no_forc_array(): """ Feature: test servable_config.py stage Description: Preprocess return numpy array not C_CONTIGUOUS Expectation: Serving server work well. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=True) def preprocess(x1): x1 = x1.reshape(3,4) x1 = x1[1:3,2:3] return x1 @register.register_method(output_names=["y"]) def add_common(x1, x2): x1 = register.add_stage(preprocess, x1, outputs_count=1, tag="Preprocess") y = register.add_stage(model, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content) instances = [] x1 = np.arange(12).astype(np.float32) x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) instances.append({"x1": x1, "x2": x2}) client = create_client("localhost:5500", base.servable_name, "add_common") result = client.infer(instances) print(result) assert len(result) == 1 assert "y" in result[0] ================================================ FILE: tests/ut/python/tests/test_start_servable_config.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ """test Serving: test servable_config""" from common import ServingTestBase, serving_test from mindspore_serving import server # test servable_config.py servable_config_import = r""" import numpy as np from mindspore_serving.server import register """ servable_config_declare_servable = r""" register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) """ servable_config_preprocess_cast = r""" def add_trans_datatype(x1, x2): return x1.astype(np.float32), x2.astype(np.float32) """ servable_config_method_add_common = r""" @register.register_method(output_names=["y"]) def add_common(x1, x2): # only support float32 inputs y = register.call_servable(x1, x2) return y """ servable_config_method_add_cast = r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) # cast input to float32 y = register.call_servable(x1, x2) return y """ @serving_test def test_register_method_common_success(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += servable_config_method_add_common servable_content += servable_config_method_add_cast base.init_servable_with_servable_config(1, servable_content) server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) @serving_test def test_register_method_no_declare_servable_failed(): base = ServingTestBase() servable_content = servable_config_import # servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += servable_config_method_add_common servable_content += servable_config_method_add_cast base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "There is no model declared, you can use declare_model to declare models" in str(e) @serving_test def test_register_method_reference_invalid_preprocess_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable # servable_content += servable_config_preprocess_cast servable_content += servable_config_method_add_common servable_content += servable_config_method_add_cast base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "name 'add_trans_datatype' is not defined" in str(e) # preprocess order error @serving_test def test_register_method_preprocess_after_predict_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.call_servable(x1, x2) x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) return x1 """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_servable should be invoked after 
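
# For contrast with the order-violation tests below, one fully legal chain
# (a sketch assembled from the fragments above; no test in this file uses it):
servable_config_method_add_full_order = r"""
@register.register_method(output_names=["y1", "y2"])
def add_full_order(x1, x2):
    x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2)
    y = register.call_servable(x1, x2)
    y1, y2 = register.call_postprocess(add_trans_datatype, y, y)
    return y1, y2
"""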
call_preprocess" in str(e) @serving_test def test_register_method_preprocess_after_postprocess_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_postprocess(add_trans_datatype, x1, x2) x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should be invoked after call_preprocess" in str(e) @serving_test def test_register_method_preprocess_after_postprocess_pipeline_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_postprocess_pipeline(add_trans_datatype, x1, x2) x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should be invoked after call_preprocess" in str(e) # preprocess_pipeline order error @serving_test def test_register_method_preprocess_pipeline_after_predict_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.call_servable(x1, x2) x1, x2 = register.call_preprocess_pipeline(add_trans_datatype, x1, x2) return x1 """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_servable should be invoked after call_preprocess_pipeline" in str(e) @serving_test def test_register_method_preprocess_pipeline_after_postprocess_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_postprocess(add_trans_datatype, x1, x2) x1, x2 = register.call_preprocess_pipeline(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should be invoked after call_preprocess_pipeline" \ in str(e) @serving_test def test_register_method_preprocess_pipeline_after_postprocess_pipeline_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += 
servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_postprocess_pipeline(add_trans_datatype, x1, x2) x1, x2 = register.call_preprocess_pipeline(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should be invoked after call_preprocess_pipeline" \ in str(e) # repeat preprocess @serving_test def test_register_method_preprocess_twice_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_preprocess or call_preprocess_pipeline should not be invoked more than once" in str(e) @serving_test def test_register_method_preprocess_twice2_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) x1, x2 = register.call_preprocess_pipeline(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_preprocess or call_preprocess_pipeline should not be invoked more than once" in str(e) @serving_test def test_register_method_preprocess_pipeline_twice_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_preprocess_pipeline(add_trans_datatype, x1, x2) x1, x2 = register.call_preprocess_pipeline(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_preprocess or call_preprocess_pipeline should not be invoked more than once" in str(e) # repeat postprocess @serving_test def test_register_method_postprocess_twice_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" def postprocess(y): return y.astype(np.int32) 
@register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.call_servable(x1, x2) y = register.call_postprocess(postprocess, y) y = register.call_postprocess(postprocess, y) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should not be invoked more than once" in str(e) @serving_test def test_register_method_postprocess_twice2_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" def postprocess(y): return y.astype(np.int32) @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.call_servable(x1, x2) y = register.call_postprocess_pipeline(postprocess, y) y = register.call_postprocess(postprocess, y) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should not be invoked more than once" in str(e) @serving_test def test_register_method_postprocess_pipeline_twice_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" def postprocess(y): return y.astype(np.int32) @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.call_servable(x1, x2) y = register.call_postprocess_pipeline(postprocess, y) y = register.call_postprocess_pipeline(postprocess, y) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should not be invoked more than once" in str(e) # call servable repeat @serving_test def test_register_method_call_servable_twice_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_servable should not be invoked more than once" in str(e) @serving_test def test_register_method_call_servable_after_postprocess_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_postprocess(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: 
server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should be invoked after call_servable" in str(e) @serving_test def test_register_method_call_servable_after_postprocess_pipeline_failed(): base = ServingTestBase() servable_content = servable_config_import servable_content += servable_config_declare_servable servable_content += servable_config_preprocess_cast servable_content += r""" @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_postprocess_pipeline(add_trans_datatype, x1, x2) y = register.call_servable(x1, x2) return y """ base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "call_postprocess or call_postprocess_pipeline should be invoked after call_servable" in str(e) @serving_test def test_register_method_without_call_servable_failed(): servable_content = r""" import numpy as np from mindspore_serving.server import register register.declare_servable(servable_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def add_trans_datatype(x1, x2): return x1.astype(np.float32), x2.astype(np.float32) def add_func(x1, x2): return x1+x2 @register.register_method(output_names=["y"]) def add_cast(x1, x2): x1, x2 = register.call_preprocess(add_trans_datatype, x1, x2) # cast input to float32 y = register.call_postprocess(add_func, x1, x2) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "Not find the invoke of 'call_servable'" in str(e) @serving_test def test_register_method_invalid_call_servable(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2): if True: y = register.call_servable(model, x1, x2) return y return x1 """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "conditions and loops are not supported in register_method when the interface 'call_servable' is used," \ " use 'add_stage' to replace 'call_servable'" in str(e) @serving_test def test_register_method_invalid_call_servable2(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) model2 = register.declare_model(model_file="tensor_add2.mindir", model_format="MindIR", with_batch_dim=False) @register.register_method(output_names="y") def predict(x1, x2): y = register.call_servable(x1, x2) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "There are more than one servable declared when the interface 'call_servable' is used, use 
'add_stage'" \ " to replace 'call_servable'" in str(e) @serving_test def test_register_method_invalid_call_preprocess(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2): if True: y = register.call_preprocess(preprocess, x1, x2) return y return x1 """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "conditions and loops are not supported in register_method when the interface 'call_preprocess'" \ " is used, use 'add_stage' to replace 'call_preprocess'" in str(e) @serving_test def test_register_method_invalid_call_preprocess_pipeline(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2): if True: y = register.call_preprocess_pipeline(preprocess, x1, x2) return y return x1 """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "conditions and loops are not supported in register_method when the interface" \ " 'call_preprocess_pipeline' is used, use 'add_stage' to replace 'call_preprocess_pipeline'" in str(e) @serving_test def test_register_method_invalid_call_postprocess(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2): if True: y = register.call_postprocess(preprocess, x1, x2) return y return x1 """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "conditions and loops are not supported in register_method when the interface 'call_postprocess'" \ " is used, use 'add_stage' to replace 'call_postprocess'" in str(e) @serving_test def test_register_method_invalid_call_postprocess_pipeline(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2): if True: y = register.call_postprocess_pipeline(preprocess, x1, x2) return y return x1 """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "conditions and loops are not supported in register_method when the interface " \ "'call_postprocess_pipeline' is used, use 'add_stage' to replace 'call_postprocess_pipeline'" in str(e) @serving_test def 
test_register_method_invalid_call_preprocess_with_condition(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2): y = register.call_preprocess(preprocess, x1, x2) if True: y = register.call_postprocess(preprocess, x1, x2) return y return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "complex statements such as conditions and loops are not supported in register_method when the " \ "interface 'call_preprocess' is used, use 'add_stage' to replace 'call_preprocess'" in str(e) @serving_test def test_register_method_invalid_call_preprocess_with_condition2(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2): if True: y = register.call_postprocess(preprocess, x1, x2) return y y = register.call_preprocess(preprocess, x1, x2) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "complex statements such as conditions and loops are not supported in register_method when the " \ "interface 'call_preprocess' is used, use 'add_stage' to replace 'call_preprocess'" in str(e) @serving_test def test_register_method_mix_call_xxx_add_stage_failed(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2, x3, x4): y = register.call_preprocess(preprocess, x1, x2) y = register.call_servable(y, x3) y = register.call_postprocess(preprocess, y, x4) y = register.add_stage(preprocess, y, x2, outputs_count=1) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "complex statements such as conditions and loops are not supported in register_method when the" in str(e) @serving_test def test_register_method_mix_call_xxx_add_stage2_failed(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2, x3, x4): y = register.add_stage(preprocess, x1, x2, outputs_count=1) y = register.call_preprocess(preprocess, y, x2) y = register.call_servable(y, x3) y = register.call_postprocess(preprocess, y, x4) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, 
device_ids=0)) assert False except RuntimeError as e: assert "complex statements such as conditions and loops are not supported in register_method when the" in str(e) @serving_test def test_register_method_mix_call_xxx_add_stage3_failed(): servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR", with_batch_dim=False) def preprocess(x1, x2): return y @register.register_method(output_names="y") def predict(x1, x2): if True: y = register.call_postprocess(preprocess, x1, x2) return y y = register.add_stage(preprocess, x1, x2, outputs_count=1) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "complex statements such as conditions and loops are not supported in register_method when the" in str(e) ================================================ FILE: tests/ut/python/tests/test_start_sevables.py ================================================ # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """test Serving with master, worker and client""" import shutil import os import numpy as np from common import ServingTestBase, serving_test, start_serving_server, create_client from mindspore_serving import server from mindspore_serving.server._servable_local import merge_config @serving_test def test_start_servable_servable_dir_invalid_failed(): """ Feature: test start servables Description: servable dir is not exist Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir + "_error", base.servable_name, device_ids=0)) assert False except RuntimeError as e: assert "Check servable config failed, directory " in str(e) # start_servable @serving_test def test_start_worker_no_servable_config_file_failed(): """ Feature: test start servables Description: servable_config.py is not exist Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "no_exist_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, version_number=0)) assert False except RuntimeError as e: assert "Check servable config failed, file " in str(e) @serving_test def test_start_worker_no_model_file_failed(): """ Feature: test start servables Description: model file is not exist Expectation: failed to serving server. 
""" base = ServingTestBase() base.init_servable(1, "add_servable_config.py", model_file="tensor_add_error.mindir") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, version_number=0)) assert False except RuntimeError as e: assert "Load model failed, servable directory: " in str(e) @serving_test def test_start_servable_servable_dir_empty_invalid_failed(): """ Feature: test start servables Description: input parameter 'servable_directory' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables(server.ServableStartConfig("", base.servable_name, device_ids=0, version_number=0)) assert False except RuntimeError as e: assert "Parameter 'servable_directory' should not be empty str" in str(e) @serving_test def test_start_worker_type_servable_dir_invalid_failed(): """ Feature: test start servables Description: input parameter 'servable_directory' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables(server.ServableStartConfig(1, base.servable_name, device_ids=0, version_number=0)) assert False except RuntimeError as e: assert "Parameter 'servable_directory' should be str, but actually " in str(e) @serving_test def test_start_worker_type_servable_name_invalid_failed(): """ Feature: test start servables Description: input parameter 'servable_name' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables(server.ServableStartConfig(base.servable_dir, False, device_ids=0, version_number=0)) assert False except RuntimeError as e: assert "Parameter 'servable_name' should be str, but actually " in str(e) @serving_test def test_start_servable_version_number_invalid_failed(): """ Feature: test start servables Description: There is no specified version model Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, version_number=2)) assert False except RuntimeError as e: assert "There is no specified version directory of models, specified version number: 2" in str(e) @serving_test def test_start_servable_version_number_invalid2_failed(): """ Feature: test start servables Description: There is no valid version directory Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(0, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, version_number=0)) assert False except RuntimeError as e: assert "There is no valid version directory of models" in str(e) @serving_test def test_start_worker_type_version_number_invalid_failed(): """ Feature: test start servables Description: input parameter 'version_number' invalid Expectation: failed to serving server. 
""" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, version_number=False)) assert False except RuntimeError as e: assert "Parameter 'version_number' should be int, but actually " in str(e) base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, version_number=-1)) assert False except RuntimeError as e: assert "Parameter 'version_number' should be >= 0" in str(e) @serving_test def test_start_worker_type_device_id_invalid_failed(): """ Feature: test start servables Description: input parameter 'device_ids' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, version_number=1, device_ids="1")) assert False except RuntimeError as e: assert "Parameter 'device_ids' should be int or tuple/list of int, but actually" in str(e) @serving_test def test_start_worker_device_id_range_invalid_failed(): """ Feature: test start servables Description: input parameter 'device_ids' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, version_number=1, device_ids=-1)) assert False except RuntimeError as e: assert "The item value '-1' in parameter 'device_ids' should be >= 0" in str(e) @serving_test def test_start_worker_type_device_type_invalid_failed(): """ Feature: test start servables Description: input parameter 'device_type' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, version_number=1, device_type=123)) assert False except RuntimeError as e: assert "Parameter 'device_type' should be str, but actually" in str(e) @serving_test def test_start_worker_device_type_value_invalid_failed(): """ Feature: test start servables Description: input parameter 'device_type' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type="InvalidDeviceType")) assert False except RuntimeError as e: assert "Unsupported device type 'InvalidDeviceType', only support 'Ascend', 'GPU', 'CPU' and None, " \ "case ignored" in str(e) @serving_test def test_start_worker_device_type_value_invalid2_failed(): """ Feature: test start servables Description: input parameter 'device_type' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type="")) assert False except RuntimeError as e: assert "Parameter 'device_type' should not be empty str" in str(e) @serving_test def test_start_worker_type_device_type_none_success(): """ Feature: test start servables Description: input parameter 'device_type' invalid Expectation: failed to serving server. 
""" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type=None)) @serving_test def test_start_worker_type_device_type_none2_success(): """ Feature: test start servables Description: input parameter 'device_type' invalid Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type='None')) @serving_test def test_servable_start_config_merge_same_version_same_device_ids_success(): """ Feature: test merge servable config Description: specified version 1 and newest version 0 can merge to one config of version 1 Expectation: success to merge config. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=0) config1 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=1) config_ret = merge_config((config0, config1)) assert len(config_ret) == 1 assert config_ret[0].version_number == 1 assert len(config_ret[0].device_ids) == 1 assert config_ret[0].device_ids[0] == 2 @serving_test def test_servable_start_config_merge_same_version_diff_device_ids_success(): """ Feature: test merge servable config Description: specified version 1 with diff device can merge to one config with device_ids merged Expectation: success to merge config. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=1, version_number=1) config1 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=(0, 2), version_number=1) config_ret = merge_config((config0, config1)) assert len(config_ret) == 1 assert config_ret[0].version_number == 1 assert len(config_ret[0].device_ids) == 3 assert 0 in config_ret[0].device_ids assert 1 in config_ret[0].device_ids assert 2 in config_ret[0].device_ids @serving_test def test_servable_start_config_merge_diff_version_diff_device_ids_success(): """ Feature: test merge servable config Description: specified version 1 and newest version 0 with diff device can merge to one config of version 1 with device_ids merged Expectation: success to merge config. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name, "1"), os.path.join(base.servable_dir, base.servable_name, "2")) config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=1, version_number=0) config1 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=1) config_ret = merge_config((config0, config1)) assert len(config_ret) == 2 assert config_ret[0].version_number == 2 # newest version assert len(config_ret[0].device_ids) == 1 assert config_ret[0].device_ids[0] == 1 assert config_ret[1].version_number == 1 assert len(config_ret[1].device_ids) == 1 assert config_ret[1].device_ids[0] == 2 @serving_test def test_servable_start_config_merge_diff_version_same_device_ids_failed(): """ Feature: test merge servable config Description: specified version 1 and newest version 0 with same device is invalid Expectation: failed to merge config. 
""" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name, "1"), os.path.join(base.servable_dir, base.servable_name, "2")) config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=0) config1 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=1) try: server.start_servables((config0, config1)) assert False except RuntimeError as e: assert "Ascend 910 device id 2 is used repeatedly in servable" in str(e) @serving_test def test_servable_start_config_same_servable_name_diff_directory_failed(): """ Feature: test merge servable config Description: specified version 1 and newest version 0 with diff servable directory is invalid Expectation: failed to merge config. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=0) config1 = server.ServableStartConfig(base.servable_dir + "2", base.servable_name, device_ids=2, version_number=1) try: server.start_servables((config0, config1)) assert False except RuntimeError as e: assert f"The servable directory of servable name {base.servable_name} is different in multiple configurations" \ in str(e) @serving_test def test_servable_start_config_multi_servable_same_device_id(): """ Feature: test merge servable config Description: diff servable same with same device id is invalid Expectation: failed to merge config. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name), os.path.join(base.servable_dir, base.servable_name + "2")) config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=0) config1 = server.ServableStartConfig(base.servable_dir, base.servable_name + "2", device_ids=2, version_number=1) try: server.start_servables((config0, config1)) assert False except RuntimeError as e: assert "Ascend 910 device id 2 is used repeatedly in servable" in str(e) @serving_test def test_servable_start_config_multi_servable_diff_device_id(): """ Feature: test merge servable config Description: servable name are same, some are diff Expectation: success to merge config. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name), os.path.join(base.servable_dir, base.servable_name + "2")) config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=(1, 3), version_number=0) config1 = server.ServableStartConfig(base.servable_dir, base.servable_name + "2", device_ids=2, version_number=1) config3 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=(4, 5), version_number=0) config_ret = merge_config((config0, config1, config3)) assert len(config_ret) == 2 print(config_ret[0].servable_name) print(config_ret[1].servable_name) assert config_ret[0].version_number == 1 # newest version assert len(config_ret[0].device_ids) == 4 assert tuple(config_ret[0].device_ids) == (1, 3, 4, 5) assert config_ret[1].version_number == 1 assert len(config_ret[1].device_ids) == 1 assert config_ret[1].device_ids[0] == 2 @serving_test def test_servable_start_config_merge_diff_version_diff_dec_key_success(): """ Feature: test merge servable config Description: diff version with diff dec key Expectation: success to merge config. 
""" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name, "1"), os.path.join(base.servable_dir, base.servable_name, "2")) config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=1, version_number=0, dec_key=("ABC" * 8).encode(), dec_mode='AES-GCM') config1 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=1, dec_key=("DEF" * 8).encode(), dec_mode='AES-CBC') config_ret = merge_config((config0, config1)) assert len(config_ret) == 2 assert config_ret[0].dec_key == ("ABC" * 8).encode() # newest version assert config_ret[0].dec_mode == "AES-GCM" assert config_ret[1].dec_key == ("DEF" * 8).encode() # newest version assert config_ret[1].dec_mode == "AES-CBC" @serving_test def test_servable_start_config_merge_same_version_diff_dec_key_failed(): """ Feature: test merge servable config Description: same version with diff dec key Expectation: failed to merge config. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=1, version_number=0, dec_key=("ABC" * 8).encode(), dec_mode='AES-GCM') config1 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=2, version_number=1, dec_key=("DEF" * 8).encode(), dec_mode='AES-CBC') try: server.start_servables((config0, config1)) assert False except RuntimeError as e: assert "The dec key or dec mode of servable name" in str(e) @serving_test def test_servable_start_config_with_dec_success(): """ Feature: test start servable with dec Description: test start servable with dec Expectation: success to start servable. """ servable_content = r""" import numpy as np from mindspore_serving.server import register tensor_add = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") @register.register_method(output_names=["y"]) def add_cast(x1, x2): y = register.add_stage(tensor_add, x1, x2, outputs_count=1) return y """ base = ServingTestBase() base.init_servable_with_servable_config(1, servable_content) server.start_servables(server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, dec_key="ABCDEFGHABCDEFGH".encode(), dec_mode='AES-GCM')) @serving_test def test_start_servables_without_declared_model_none_device_ids_start_version0_success(): """ Feature: test start servables Description: no models, no device ids, with extra workers, no version directory, start version number 0 Expectation: serving server running ok. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=[], device_ids=None, version_number=0, start_version_number=0) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "predict", version_number=1) result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_start_servables_without_declared_model_none_device_ids_start_version1_success(): """ Feature: test start servables Description: no models, no device ids, with extra workers, no version directory, start version number 1 Expectation: serving server running ok. """ servable_content = r""" import numpy as np from mindspore_serving.server import register def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=[], device_ids=None, version_number=0, start_version_number=1) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] * 10 client = create_client("localhost:5500", base.servable_name, "predict", version_number=1) result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_start_servables_without_declared_model_with_device_ids_start_version0_success(): """ Feature: test start servables Description: no models, with device ids, without extra workers, no version directory, start version number 0 Expectation: serving server running ok. """ servable_content = r""" import numpy as np from mindspore_serving.server import register def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=[], device_ids=0, version_number=0, start_version_number=0) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] * 10 client = create_client("localhost:5500", base.servable_name, "predict", version_number=1) result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_start_servables_without_declared_model_with_device_ids_start_version0_with_extra_worker_success(): """ Feature: test start servables Description: no models, with device ids, without extra workers, no version directory, start version number 0 Expectation: serving server running ok. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=[], device_ids=0, num_parallel_workers=2, version_number=0, start_version_number=0) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] * 10 client = create_client("localhost:5500", base.servable_name, "predict", version_number=1) result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_start_servables_without_declared_model_with_device_ids_start_version1_with_extra_worker_success(): """ Feature: test start servables Description: no models, with device ids, with extra workers, no version directory, start version number 1 Expectation: serving server running ok. """ servable_content = r""" import numpy as np from mindspore_serving.server import register def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ base = start_serving_server(servable_content, model_file=[], device_ids=0, num_parallel_workers=2, version_number=0, start_version_number=1) # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] * 10 client = create_client("localhost:5500", base.servable_name, "predict", version_number=1) result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_start_servables_with_declared_model_none_device_ids_start_version0_with_extra_worker_fail(): """ Feature: test start servables Description: with models, none device ids, with extra workers, no version directory, start version number 0 Expectation: failed to serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir", device_ids=None, num_parallel_workers=2, version_number=None, start_version_number=0) assert False except RuntimeError as e: assert "There is no valid version directory of models" in str(e) @serving_test def test_start_servables_with_declared_model_none_device_ids_start_version1_with_extra_worker_fail(): """ Feature: test start servables Description: with models, none device ids, with extra workers, no version directory, start version number 1 Expectation: failed to serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir", device_ids=None, num_parallel_workers=2, version_number=None, start_version_number=1) assert False except RuntimeError as e: assert "There is no valid version directory of models" in str(e) @serving_test def test_start_servables_with_declared_model_none_device_ids_start_version0_with_version_dir_fail(): """ Feature: test start servables Description: with models, none device ids, with extra workers, with version directory, start version number 1 Expectation: failed to serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir", device_ids=None, num_parallel_workers=2, version_number=1, start_version_number=0) assert False except RuntimeError as e: # "Servable '{}' has models declared by declare_model, but parameter 'device_ids'" assert " has models declared by declare_model, but parameter 'device_ids'" in str(e) @serving_test def test_start_servables_with_declared_model_none_device_ids_start_version1_with_version_dir_fail(): """ Feature: test start servables Description: with models, none device ids, with extra workers, with version directory, start version number 1 Expectation: failed to serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir", device_ids=None, num_parallel_workers=2, version_number=1, start_version_number=1) assert False except RuntimeError as e: # "Servable '{}' has models declared by declare_model, but parameter 'device_ids'" assert " has models declared by declare_model, but parameter 'device_ids'" in str(e) @serving_test def test_start_servables_with_declared_model_with_device_ids_start_version0_without_version_dir_fail(): """ Feature: test start servables Description: with models, with device ids, with extra workers, without version directory, start version number 0 Expectation: failed to serving server. 
""" servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir", device_ids=1, num_parallel_workers=2, version_number=None, start_version_number=0) assert False except RuntimeError as e: assert "There is no valid version directory of models" in str(e) @serving_test def test_start_servables_with_declared_model_with_device_ids_start_version1_without_version_dir_fail(): """ Feature: test start servables Description: with models, with device ids, with extra workers, without version directory, start version number 1 Expectation: failed to serving server. """ servable_content = r""" import numpy as np from mindspore_serving.server import register model = register.declare_model(model_file="tensor_add.mindir", model_format="MindIR") def function_test(x1, x2): y = x1+x2 return y @register.register_method(output_names="y") def predict(x1, x2): y = register.add_stage(function_test, x1, x2, outputs_count=1) return y """ try: start_serving_server(servable_content, model_file="tensor_add.mindir", device_ids=1, num_parallel_workers=2, version_number=None, start_version_number=1) assert False except RuntimeError as e: assert "There is no valid version directory of models" in str(e) @serving_test def test_start_servables_enable_cpu_none_device_id_cpu_device_type_success(): """ Feature: test start servables Description: target cpu, device ids none, device type CPU Expectation: serving server running ok. """ os.environ["SERVING_ENABLE_CPU_DEVICE"] = "1" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=None, device_type="CPU")) server.start_grpc_server("localhost:5500") # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "add_common", version_number=1) result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_start_servables_enable_cpu_none_device_id_none_device_type_none_success(): """ Feature: test start servables Description: enable cpu, device ids none, device type none Expectation: serving server running ok. """ os.environ["SERVING_ENABLE_CPU_DEVICE"] = "1" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=None, device_type=None)) server.start_grpc_server("localhost:5500") # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "add_common", version_number=1) result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_start_servables_enable_cpu_device_type_with_device_id_cpu_device_type_success(): """ Feature: test start servables Description: target cpu, with device ids, device type CPU Expectation: serving server running ok. 
""" os.environ["SERVING_ENABLE_CPU_DEVICE"] = "1" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type="CPU")) server.start_grpc_server("localhost:5500") # Client x1 = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) x2 = np.array([[5.5, 6.6], [7.7, 8.8]], np.float32) y = x1 + x2 instances = [{"x1": x1, "x2": x2}] client = create_client("localhost:5500", base.servable_name, "add_common", version_number=1) result = client.infer(instances) print("result", result) assert (result[0]["y"] == y).all() @serving_test def test_start_servables_ascend_device_reuse_device_ids_failed(): """ Feature: test start servables Description: Ascend device, target device type Ascend, reuse device failed Expectation: Serving server startup failed. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name), os.path.join(base.servable_dir, base.servable_name + "_x")) try: config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type="Ascend") config1 = server.ServableStartConfig(base.servable_dir, base.servable_name + "_x", device_ids=0, device_type="Ascend") server.start_servables([config0, config1]) assert False except RuntimeError as e: assert "Ascend 910 device id 0 is used repeatedly in servable" in str(e) @serving_test def test_start_servables_ascend_device_reuse_device_ids_none_device_type_failed(): """ Feature: test start servables Description: Ascend device, target device type none, reuse device failed Expectation: Serving server startup failed. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name), os.path.join(base.servable_dir, base.servable_name + "_x")) try: config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0) config1 = server.ServableStartConfig(base.servable_dir, base.servable_name + "_x", device_ids=0) server.start_servables([config0, config1]) assert False except RuntimeError as e: assert "Ascend 910 device id 0 is used repeatedly in servable" in str(e) @serving_test def test_start_servables_ascend_device_without_reuse_device_ids_none_device_type_success(): """ Feature: test start servables Description: Ascend device, target device type Ascend, without reuse device success Expectation: Serving server work well. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name), os.path.join(base.servable_dir, base.servable_name + "_x")) config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type="Ascend") config1 = server.ServableStartConfig(base.servable_dir, base.servable_name + "_x", device_ids=1, device_type="Ascend") server.start_servables([config0, config1]) @serving_test def test_start_servables_gpu_device_reuse_device_ids_success(): """ Feature: test start servables Description: GPU device, target device type GPU, reuse device success Expectation: Serving server work well. 
""" os.environ["SERVING_ENABLE_GPU_DEVICE"] = "1" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name), os.path.join(base.servable_dir, base.servable_name + "_x")) config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type="GPU") config1 = server.ServableStartConfig(base.servable_dir, base.servable_name + "_x", device_ids=0, device_type="GPU") server.start_servables([config0, config1]) @serving_test def test_start_servables_gpu_device_reuse_device_ids_none_device_type_success(): """ Feature: test start servables Description: GPU device, target device type GPU, reuse device success Expectation: Serving server work well. """ os.environ["SERVING_ENABLE_GPU_DEVICE"] = "1" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name), os.path.join(base.servable_dir, base.servable_name + "_x")) config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0) config1 = server.ServableStartConfig(base.servable_dir, base.servable_name + "_x", device_ids=0) server.start_servables([config0, config1]) @serving_test def test_start_servables_gpu_device_ascend_device_type_failed(): """ Feature: test start servables Description: GPU device, target device type Ascend Expectation: Serving server start failed. """ os.environ["SERVING_ENABLE_GPU_DEVICE"] = "1" base = ServingTestBase() base.init_servable(1, "add_servable_config.py") shutil.copytree(os.path.join(base.servable_dir, base.servable_name), os.path.join(base.servable_dir, base.servable_name + "_x")) try: config0 = server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, device_type="Ascend") config1 = server.ServableStartConfig(base.servable_dir, base.servable_name + "_x", device_ids=1, device_type="Ascend") server.start_servables([config0, config1]) assert False except RuntimeError as e: assert f"The device type 'ascend' of servable name {base.servable_name} is inconsistent with current " \ f"running environment" in str(e) @serving_test def test_start_servable_number_of_worker_invalid_failed(): """ Feature: test start servables Description: num_parallel_workers not in range[0,64] Expectation: failed to serving server. """ base = ServingTestBase() base.init_servable(1, "add_servable_config.py") try: server.start_servables( server.ServableStartConfig(base.servable_dir, base.servable_name, device_ids=0, num_parallel_workers=65)) assert False except RuntimeError as e: assert "Parameter 'num_parallel_workers' should be in range [0,64]" in str(e) ================================================ FILE: tests/ut/runtest.sh ================================================ #!/bin/bash # Copyright 2019 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================

set -e
CURRPATH=$(
  cd "$(dirname $0)" || exit
  pwd
)

if [ $# -gt 0 ]; then
  if [ $1 == "python" ]; then
    echo "run python ut"
    bash ${CURRPATH}/python/runtest.sh $2
  elif [ $1 == "cpp" ]; then
    echo "run cpp ut"
    bash ${CURRPATH}/cpp/runtest.sh
  fi
else
  echo "run all ut"
  # 1. run python testcases
  bash ${CURRPATH}/python/runtest.sh $2
  # 2. run c++ ut testcases
  bash ${CURRPATH}/cpp/runtest.sh
fi


================================================
FILE: tests/ut/stub/cxx_api/cell.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/api/cell.h"
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/graph/graph_impl.h"

namespace mindspore {
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }

ParameterCell::ParameterCell(const ParameterCell &cell) {
  auto tmp_ptr = cell.tensor_.Clone();
  tensor_ = *tmp_ptr;
  MSTensor::DestroyTensorPtr(tmp_ptr);
}

ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
  if (&cell == this) {
    return *this;
  }
  auto tmp_ptr = cell.tensor_.Clone();
  tensor_ = *tmp_ptr;
  MSTensor::DestroyTensorPtr(tmp_ptr);
  return *this;
}

ParameterCell::ParameterCell(ParameterCell &&cell) : tensor_(cell.tensor_) {}

ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
  if (&cell == this) {
    return *this;
  }
  tensor_ = cell.tensor_;
  return *this;
}

ParameterCell::ParameterCell(const MSTensor &tensor) {
  auto tmp_ptr = tensor.Clone();
  tensor_ = *tmp_ptr;
  MSTensor::DestroyTensorPtr(tmp_ptr);
}

ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
  auto tmp_ptr = tensor.Clone();
  tensor_ = *tmp_ptr;
  MSTensor::DestroyTensorPtr(tmp_ptr);
  return *this;
}

ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}

ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
  tensor_ = tensor;
  return *this;
}

GraphCell::GraphCell(const Graph &graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }

GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_EXCEPTION_IF_NULL(graph_); }

GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }

void GraphCell::SetContext(const std::shared_ptr<Context> &context) {
  if (executor_ == nullptr) {
    executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
    if (executor_ == nullptr) {
      MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
      return;
    }
    executor_->SetGraph(graph_);
  }
  executor_->SetContext(context);
}

Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  if (executor_ == nullptr) {
    executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
    if (executor_ == nullptr) {
      MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
      return kMEFailed;
    }
    executor_->SetGraph(graph_);
  }
  return executor_->Run(inputs, outputs);
}

Status GraphCell::Load(uint32_t device_id) {
  if (executor_ == nullptr) {
    executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
    if (executor_ == nullptr) {
      MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
      return kMEFailed;
    }
    executor_->SetGraph(graph_);
  }
  return executor_->Load(device_id);
}

std::vector<MSTensor> GraphCell::GetInputs() {
  if (executor_ == nullptr) {
    executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
    if (executor_ == nullptr) {
      MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
      return {};
    }
    executor_->SetGraph(graph_);
  }
  return executor_->GetInputs();
}

std::vector<MSTensor> GraphCell::GetOutputs() {
  if (executor_ == nullptr) {
    executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
    if (executor_ == nullptr) {
      MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
      return {};
    }
    executor_->SetGraph(graph_);
  }
  return executor_->GetOutputs();
}

InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}

InputAndOutput::InputAndOutput(const MSTensor &tensor) : prev_(), index_(-1) {
  auto tmp_ptr = tensor.Clone();
  cell_ = std::make_shared<ParameterCell>(*tmp_ptr);
  MSTensor::DestroyTensorPtr(tmp_ptr);
}

InputAndOutput::InputAndOutput(MSTensor &&tensor)
    : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}

InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
                               int32_t index)
    : cell_(cell), prev_(prev), index_(index) {}
}  // namespace mindspore
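// A hedged note on the stub above: every GraphCell entry point lazily creates
// its executor through Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target)
// on first use, so the stub behaves the same no matter which of
// SetContext/Run/Load/GetInputs/GetOutputs is called first. A typical call
// sequence against it (variable names illustrative):
//
//   GraphCell cell(graph);
//   cell.SetContext(context);    // first call, creates the executor via the factory
//   Status ret = cell.Load(0);   // device id 0
//   std::vector<MSTensor> outputs;
//   ret = cell.Run(cell.GetInputs(), &outputs);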

================================================
FILE: tests/ut/stub/cxx_api/context.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include "include/api/context.h" #include #include #include #include "cxx_api/factory.h" #include "utils/log_adapter.h" constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16"; constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16"; constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency"; constexpr auto kModelOptionDeviceID = "mindspore.option.device_id"; constexpr auto kModelOptionGPUDeviceID = kModelOptionDeviceID; constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode"; constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID; constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID; constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path"; constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format"; constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map"; constexpr auto kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape"; constexpr auto kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type"; constexpr auto kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode"; constexpr auto kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode"; constexpr auto KModelOptionAscend310FusionSwitchCfgPath = "mindspore.option.ascend310.fusion_switch_config_file_path"; constexpr auto kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size"; constexpr auto kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize"; namespace mindspore { class Allocator {}; struct Context::Data { std::vector> device_info_list; int32_t thread_num; bool enable_parallel_ = false; std::vector affinity_core_list_; int affinity_mode_ = 2; }; struct DeviceInfoContext::Data { std::map params; }; Context::Context() : data_(std::make_shared()) {} template >> static const U &GetValue(const std::shared_ptr &data, const std::string &key) { static const U empty_result{}; if (data == nullptr) { return empty_result; } auto iter = data->params.find(key); if (iter == data->params.end()) { return empty_result; } const std::any &value = iter->second; if (value.type() != typeid(U)) { return empty_result; } return std::any_cast(value); } void Context::SetThreadNum(int32_t thread_num) { MS_EXCEPTION_IF_NULL(data_); data_->thread_num = thread_num; } int32_t Context::GetThreadNum() const { MS_EXCEPTION_IF_NULL(data_); return data_->thread_num; } void Context::SetEnableParallel(bool is_parallel) { MS_EXCEPTION_IF_NULL(data_); data_->enable_parallel_ = is_parallel; } bool Context::GetEnableParallel() const { MS_EXCEPTION_IF_NULL(data_); return data_->enable_parallel_; } void Context::SetThreadAffinity(int mode) { MS_EXCEPTION_IF_NULL(data_); data_->affinity_mode_ = mode; } int Context::GetThreadAffinityMode() const { MS_EXCEPTION_IF_NULL(data_); return data_->affinity_mode_; } void Context::SetThreadAffinity(const std::vector &core_list) { MS_EXCEPTION_IF_NULL(data_); data_->affinity_core_list_ = core_list; } std::vector Context::GetThreadAffinityCoreList() const { MS_EXCEPTION_IF_NULL(data_); return data_->affinity_core_list_; } std::vector> &Context::MutableDeviceInfo() { MS_EXCEPTION_IF_NULL(data_); return data_->device_info_list; } DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared()) {} void CPUDeviceInfo::SetEnableFP16(bool 
void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionCpuEnableFP16] = is_fp16;
}
bool CPUDeviceInfo::GetEnableFP16() const {
  MS_EXCEPTION_IF_NULL(data_);
  return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
}

void GPUDeviceInfo::SetEnableFP16(bool is_fp16) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionGPUEnableFP16] = is_fp16;
}
bool GPUDeviceInfo::GetEnableFP16() const {
  MS_EXCEPTION_IF_NULL(data_);
  return GetValue<bool>(data_, kModelOptionGPUEnableFP16);
}

void KirinNPUDeviceInfo::SetFrequency(int frequency) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionKirinNpuFrequency] = frequency;
}
int KirinNPUDeviceInfo::GetFrequency() const {
  MS_EXCEPTION_IF_NULL(data_);
  return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
}

void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionGPUDeviceID] = device_id;
}
uint32_t GPUDeviceInfo::GetDeviceID() const {
  MS_EXCEPTION_IF_NULL(data_);
  return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
}

int GPUDeviceInfo::GetRankID() const {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return 0;
}

int GPUDeviceInfo::GetGroupSize() const {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return 0;
}

void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionGPUPrecisionMode] = CharToString(precision_mode);
}
std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, kModelOptionGPUPrecisionMode);
  return StringToChar(ref);
}

void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310DeviceID] = device_id;
}
uint32_t AscendDeviceInfo::GetDeviceID() const {
  MS_EXCEPTION_IF_NULL(data_);
  return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
}

void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
}
std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
  return StringToChar(ref);
}

void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
}
std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
  return StringToChar(ref);
}

void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
}
std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
  return StringToChar(ref);
}

void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
  MS_EXCEPTION_IF_NULL(data_);
  std::string batchs = "";
  for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
    if (i != 0) {
      batchs.push_back(',');
    }
    batchs += std::to_string(dynamic_batch_size[i]);
  }
  data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
}
std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
  return StringToChar(ref);
}
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &) { return; }

std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }

void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
}
std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
  return StringToChar(ref);
}

void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
}
std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
  return StringToChar(ref);
}

void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
}
std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
  return StringToChar(ref);
}

void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310InputShapeMap] = shape;
}
std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
  MS_EXCEPTION_IF_NULL(data_);
  return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
}

void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310OutputType] = output_type;
}
enum DataType AscendDeviceInfo::GetOutputType() const {
  MS_EXCEPTION_IF_NULL(data_);
  return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
}

void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
  MS_EXCEPTION_IF_NULL(data_);
  data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
}
std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
  MS_EXCEPTION_IF_NULL(data_);
  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
  return StringToChar(ref);
}
}  // namespace mindspore


================================================
FILE: tests/ut/stub/cxx_api/factory.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H
#define MINDSPORE_CCSRC_CXX_API_FACTORY_H
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "utils/utils.h"

namespace mindspore {
inline enum DeviceType g_device_target = kInvalidDeviceType;

static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) {
  std::map<DeviceType, std::string> type_str_map = {
    {kAscend, "Ascend"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"}, {kGPU, "GPU"}, {kCPU, "CPU"}};
  auto it = type_str_map.find(device_type);
  if (it != type_str_map.end()) {
    stream << it->second;
  } else {
    stream << "[InvalidDeviceType: " << static_cast<int>(device_type) << "]";
  }
  return stream;
}

template <class T>
class Factory {
  using U = std::function<std::shared_ptr<T>()>;

 public:
  Factory(const Factory &) = delete;
  void operator=(const Factory &) = delete;

  static Factory &Instance() {
    static Factory instance;
    return instance;
  }

  void Register(U &&creator) { creators_.push_back(creator); }

  std::shared_ptr<T> Create(enum DeviceType device_type) {
    for (auto &item : creators_) {
      MS_EXCEPTION_IF_NULL(item);
      auto val = item();
      if (val->CheckDeviceSupport(device_type)) {
        return val;
      }
    }
    MS_LOG(WARNING) << "Unsupported device target " << device_type;
    return nullptr;
  }

 private:
  Factory() = default;
  ~Factory() = default;
  std::vector<U> creators_;
};

template <class T>
class Registrar {
  using U = std::function<std::shared_ptr<T>()>;

 public:
  explicit Registrar(U creator) { Factory<T>::Instance().Register(std::move(creator)); }
  ~Registrar() = default;
};

#define API_FACTORY_REG(BASE_CLASS, DERIVE_CLASS)                          \
  static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_reg( \
    []() { return std::make_shared<DERIVE_CLASS>(); });
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_FACTORY_H
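// A hedged usage sketch for the factory above: a GraphImpl backend registers
// itself once via the macro, and callers later select it by device support, as
// the Ascend stub below does:
//
//   API_FACTORY_REG(GraphCell::GraphImpl, AscendGraphImpl);
//   ...
//   auto impl = Factory<GraphCell::GraphImpl>::Instance().Create(kAscend910);
//   // returns the first registered creator whose product passes CheckDeviceSupport(kAscend910)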
*/ #include "cxx_api/graph/ascend/ascend_graph_impl.h" #include #include "include/api/context.h" #include "cxx_api/factory.h" #include "stub/graph_impl_stub.h" namespace mindspore { API_FACTORY_REG(GraphCell::GraphImpl, AscendGraphImpl); AscendGraphImpl::AscendGraphImpl() { graph_imp_stub_ = std::make_shared(); } AscendGraphImpl::~AscendGraphImpl() {} std::vector AscendGraphImpl::GetInputs() { return graph_imp_stub_->GetInputs(); } std::vector AscendGraphImpl::GetOutputs() { return graph_imp_stub_->GetOutputs(); } Status AscendGraphImpl::Load(uint32_t device_id) { graph_imp_stub_->SetGraph(graph_); graph_imp_stub_->SetContext(graph_context_); return graph_imp_stub_->Load(device_id); } Status AscendGraphImpl::Run(const std::vector &inputs, std::vector *outputs) { return graph_imp_stub_->Run(inputs, outputs); } bool AscendGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { return graph_imp_stub_->CheckDeviceSupport(device_type); } } // namespace mindspore ================================================ FILE: tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H #define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H #include #include #include #include #include #include #include "include/api/status.h" #include "include/api/graph.h" #include "cxx_api/graph/graph_impl.h" #include "cxx_api/model/model_impl.h" namespace mindspore { class AscendGraphImpl : public GraphCell::GraphImpl { public: AscendGraphImpl(); ~AscendGraphImpl() override; Status Run(const std::vector &inputs, std::vector *outputs) override; Status Load(uint32_t device_id) override; std::vector GetInputs() override; std::vector GetOutputs() override; bool CheckDeviceSupport(mindspore::DeviceType device_type) override; private: std::shared_ptr graph_imp_stub_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H ================================================ FILE: tests/ut/stub/cxx_api/graph/graph.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "include/api/graph.h" #include "cxx_api/graph/graph_data.h" #include "utils/log_adapter.h" namespace mindspore { Graph::Graph() : graph_data_(nullptr) {} Graph::Graph(const std::shared_ptr &graph_data) : graph_data_(graph_data) {} Graph::Graph(std::shared_ptr &&graph_data) : graph_data_(graph_data) {} Graph::~Graph() {} Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {} bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; } bool Graph::operator!=(std::nullptr_t) const { return graph_data_ != nullptr; } ModelType Graph::ModelType() const { MS_EXCEPTION_IF_NULL(graph_data_); return graph_data_->ModelType(); } } // namespace mindspore ================================================ FILE: tests/ut/stub/cxx_api/graph/graph_data.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "cxx_api/graph/graph_data.h" #include "utils/log_adapter.h" #ifdef ENABLE_ACL #include "framework/common/helper/model_helper.h" #endif namespace mindspore { Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type) : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) { if (model_type != ModelType::kMindIR) { MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type; } func_graph_ = func_graph; model_type_ = model_type; } Graph::GraphData::GraphData(const Buffer &om_data, enum ModelType model_type) : func_graph_(nullptr), om_data_(om_data), model_type_(model_type) { if (model_type_ != ModelType::kOM) { MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type_; } #ifdef ENABLE_ACL // check om ge::ModelHelper helper; ge::ModelData model_data; model_data.model_data = om_data_.MutableData(); model_data.model_len = om_data_.DataSize(); ge::Status ret = helper.LoadRootModel(model_data); if (ret != ge::SUCCESS) { MS_LOG(EXCEPTION) << "Invalid input data cannot parse to om."; } #else MS_LOG(EXCEPTION) << "Unsupported ModelType OM."; #endif } Graph::GraphData::~GraphData() {} FuncGraphPtr Graph::GraphData::GetFuncGraph() const { if (model_type_ != ModelType::kMindIR) { MS_LOG(ERROR) << "Invalid ModelType " << model_type_; return nullptr; } return func_graph_; } Buffer Graph::GraphData::GetOMData() const { if (model_type_ != ModelType::kOM) { MS_LOG(ERROR) << "Invalid ModelType " << model_type_; return Buffer(); } return om_data_; } } // namespace mindspore ================================================ FILE: tests/ut/stub/cxx_api/graph/graph_data.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
================================================
FILE: tests/ut/stub/cxx_api/graph/graph_data.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "include/api/graph.h"
#include "include/api/types.h"
#include "utils/utils.h"

namespace mindspore {
class Graph::GraphData {
 public:
  GraphData();

  explicit GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type = kMindIR);

  GraphData(const Buffer &om_data, enum ModelType model_type);

  ~GraphData();

  enum ModelType ModelType() const { return model_type_; }

  FuncGraphPtr GetFuncGraph() const;

  Buffer GetOMData() const;

 private:
  FuncGraphPtr func_graph_;
  Buffer om_data_;
  enum ModelType model_type_;
};
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H


================================================
FILE: tests/ut/stub/cxx_api/graph/graph_impl.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/cell.h"
#include "include/api/graph.h"
#include "include/api/context.h"
#include "cxx_api/graph/graph_data.h"
#include "utils/utils.h"

namespace mindspore {
class GraphCell::GraphImpl {
 public:
  GraphImpl() : graph_(nullptr), graph_context_(nullptr) {}
  virtual ~GraphImpl() = default;

  std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
  void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
  void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; }

  virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
  virtual Status Load(uint32_t device_id) = 0;

  virtual std::vector<MSTensor> GetInputs() = 0;
  virtual std::vector<MSTensor> GetOutputs() = 0;

  virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;

 protected:
  std::shared_ptr<Graph> graph_;
  std::shared_ptr<Context> graph_context_;
};
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
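Every concrete GraphImpl is driven in the same order by the layers above it: SetGraph and SetContext first, Load(device_id) once, then Run per inference request. A hedged sketch of that lifecycle, where `graph` and `context` stand for values you already hold and the device id/type are illustrative:

// Sketch, not a repository file: the call order imposed on a GraphImpl.
auto impl = mindspore::Factory<mindspore::GraphCell::GraphImpl>::Instance().Create(mindspore::kAscend);
if (impl != nullptr) {
  impl->SetGraph(graph);      // graph: std::shared_ptr<Graph> wrapping a GraphData
  impl->SetContext(context);  // context: std::shared_ptr<Context> carrying device info
  if (impl->Load(0) == mindspore::kSuccess) {  // device_id 0
    std::vector<mindspore::MSTensor> outputs;
    (void)impl->Run(impl->GetInputs(), &outputs);
  }
}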
================================================
FILE: tests/ut/stub/cxx_api/model/model.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/serialization.h"
#include "cxx_api/model/model_impl.h"
#include "cxx_api/factory.h"
#include "utils/utils.h"

namespace mindspore {
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
                    const std::shared_ptr<TrainCfg> &) {
  if (graph_cell.GetGraph() == nullptr) {
    MS_LOG(ERROR) << "Invalid graph input.";
    return kMCInvalidInput;
  }
  if (model_context == nullptr) {
    MS_LOG(ERROR) << "Invalid model context.";
    return kMCInvalidInput;
  }
  auto &device_info = model_context->MutableDeviceInfo();
  if (device_info.size() < 1) {
    MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
    return kMCInvalidInput;
  }
  auto device_target = device_info[0]->GetDeviceType();
  impl_ = Factory<ModelImpl>::Instance().Create(device_target);
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Create session type " << device_target << " failed";
    return kMEFailed;
  }
  g_device_target = device_target;
  impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
  impl_->SetContext(model_context);
  return impl_->Build();
}

Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
                    const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
                    const std::vector<char> &cropto_lib_path) {
  mindspore::Graph graph;
  auto status = mindspore::Serialization::Load(CharToString(model_path), model_type, &graph, dec_key, dec_mode);
  if (!status.IsOk()) {
    return status;
  }
  return Build(GraphCell(graph), model_context);
}

Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
                    const std::shared_ptr<Context> &model_context) {
  mindspore::Graph graph;
  auto status = mindspore::Serialization::Load(CharToString(model_path), model_type, &graph);
  if (!status.IsOk()) {
    return status;
  }
  return Build(GraphCell(graph), model_context);
}

Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return kMCFailed;
  }
  return impl_->Resize(inputs, dims);
}

Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                      const MSKernelCallBack &before, const MSKernelCallBack &after) {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return kMCFailed;
  }
  return impl_->Predict(inputs, outputs);
}

std::vector<MSTensor> Model::GetInputs() {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return {};
  }
  return impl_->GetInputs();
}

std::vector<MSTensor> Model::GetOutputs() {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return {};
  }
  return impl_->GetOutputs();
}

MSTensor Model::GetInputByTensorName(const std::vector<char> &tensor_name) {
  std::string tensor_name_str = CharToString(tensor_name);
  auto inputs = GetInputs();
  for (auto in : inputs) {
    if (in.Name() == tensor_name_str) {
      return in;
    }
  }
  return MSTensor(nullptr);
}

std::vector<std::vector<char>> Model::GetOutputTensorNamesChar() {
  std::vector<std::vector<char>> ret;
  auto outputs = GetOutputs();
  std::transform(outputs.begin(), outputs.end(), std::back_inserter(ret),
                 [](const MSTensor &item) -> std::vector<char> { return StringToChar(item.Name()); });
  return ret;
}

MSTensor Model::GetOutputByTensorName(const std::vector<char> &tensor_name) {
  std::string tensor_name_str = CharToString(tensor_name);
  auto outputs = GetOutputs();
  for (auto out : outputs) {
    if (out.Name() == tensor_name_str) {
      return out;
    }
  }
  return MSTensor(nullptr);
}

std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
  return std::vector<MSTensor>{GetOutputByTensorName(node_name)};
}

Model::Model() : impl_(nullptr) {}

Model::~Model() {}

bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
  auto check_model = Factory<ModelImpl>::Instance().Create(device_type);
  if (check_model == nullptr) {
    return false;
  }
  return check_model->CheckModelSupport(model_type);
}

Status Model::LoadConfig(const std::vector<char> &config_path) {
  if (common::DirOrFileExist(CharToString(config_path))) {
    return kSuccess;
  }
  MS_LOG(ERROR) << "The config file path: " << CharToString(config_path) << " doesn't exist";
  return kMCFailed;
}
}  // namespace mindspore


================================================
FILE: tests/ut/stub/cxx_api/model/model_impl.cc
================================================
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "cxx_api/model/model_impl.h"

namespace mindspore {
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  if (graph_ == nullptr) {
    MS_LOG(ERROR) << "Invalid data, graph_ is null.";
    return kMCFailed;
  }
  if (graph_cell_ == nullptr) {
    MS_LOG(WARNING) << "Model has not been built, it will be built with default options";
    Status ret = Build();
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Build model failed.";
      return ret;
    }
  }
  MS_EXCEPTION_IF_NULL(graph_cell_);
  Status ret = graph_cell_->Run(inputs, outputs);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Run graph failed.";
    return ret;
  }
  return kSuccess;
}
}  // namespace mindspore
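The two layers guard differently: the public Model methods reject any call while `impl_` is still nullptr, whereas ModelImpl::Predict tolerates a missing graph cell and builds with default options. A hedged sketch of the public-API guard (values illustrative):

// Sketch, not a repository file: calling Predict before Build fails at the
// public API layer (impl_ is nullptr), regardless of ModelImpl's internal
// build-on-demand fallback.
mindspore::Model model;
std::vector<mindspore::MSTensor> outputs;
auto status = model.Predict({}, &outputs);  // logs "has not been built"
bool failed = (status == mindspore::kMCFailed);  // true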
 */
#ifndef MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/graph.h"
#include "cxx_api/graph/graph_data.h"
#include "utils/utils.h"

namespace mindspore {
class ModelImpl {
 public:
  ModelImpl() = default;
  virtual ~ModelImpl() = default;

  virtual Status Build() = 0;
  virtual Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) = 0;

  virtual Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

  virtual std::vector<MSTensor> GetInputs() = 0;
  virtual std::vector<MSTensor> GetOutputs() = 0;

  virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
  virtual bool CheckModelSupport(enum ModelType model_type) = 0;

 protected:
  FuncGraphPtr GetFuncGraph() const {
    if (graph_->ModelType() != ModelType::kMindIR) {
      return nullptr;
    }
    auto graph_data = graph_->graph_data_;
    MS_EXCEPTION_IF_NULL(graph_data);
    return graph_data->GetFuncGraph();
  }

  std::shared_ptr<Graph> graph_ = nullptr;
  std::shared_ptr<GraphCell> graph_cell_ = nullptr;
  std::shared_ptr<Context> model_context_ = nullptr;

 private:
  friend class Model;
  void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
  void SetContext(const std::shared_ptr<Context> &model_context) {
    if (model_context != nullptr) {
      model_context_ = std::make_shared<Context>(*model_context);
    }
  }
};
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H


================================================
FILE: tests/ut/stub/cxx_api/model/ms/ms_model.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "cxx_api/model/ms/ms_model.h"
#include <memory>
#include <set>
#include "include/api/context.h"
#include "cxx_api/factory.h"

namespace mindspore {
// mindspore-serving checks the current package for version compatibility via the ModelImpl factory.
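// The registration below expands (per API_FACTORY_REG in cxx_api/factory.h) to roughly:
//   static const Registrar<ModelImpl> g_api_MsModel_registrar_reg(
//     []() { return std::make_shared<MsModel>(); });
// i.e. a file-scope Registrar constructed before main(), which pushes the MsModel
// creator into Factory<ModelImpl> so that Model::Build can resolve it by device type.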
API_FACTORY_REG(ModelImpl, MsModel);

static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
  std::string shape_key;
  for (size_t i = 0; i < dims.size(); ++i) {
    shape_key += std::to_string(i) + ":";
    for (size_t j = 0; j < dims[i].size(); ++j) {
      shape_key += std::to_string(dims[i][j]);
      if (j + 1 < dims[i].size()) {
        shape_key += ",";
      }
    }
    if (i + 1 < dims.size()) {
      shape_key += ";";
    }
  }
  return shape_key;
}

std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims) {
  std::string shape_key = GenerateShapeKey(dims);
  if (auto iter = dynamic_size_graph_map_.find(shape_key); iter != dynamic_size_graph_map_.end()) {
    MS_LOG(INFO) << "These options have been built, read cache.";
    return iter->second;
  }
  auto func_graph = ModelImpl::GetFuncGraph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(func_graph, ModelType::kMindIR));
  MS_EXCEPTION_IF_NULL(graph);
  auto graph_cell = std::make_shared<GraphCell>(graph);
  MS_EXCEPTION_IF_NULL(graph_cell);
  graph_cell->SetContext(model_context_);
  auto ret = graph_cell->Load(GetDeviceID());
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Load failed.";
    return nullptr;
  }
  dynamic_size_graph_map_[shape_key] = graph_cell;
  return graph_cell;
}

Status MsModel::Build() {
  MS_LOG(INFO) << "Start build model.";
  MS_EXCEPTION_IF_NULL(graph_);
  if (graph_cell_ != nullptr) {
    MS_LOG(INFO) << "This model has been built, skip.";
    return kSuccess;
  }
  auto func_graph = ModelImpl::GetFuncGraph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(func_graph, ModelType::kMindIR));
  MS_EXCEPTION_IF_NULL(graph);
  auto graph_cell = std::make_shared<GraphCell>(graph);
  MS_EXCEPTION_IF_NULL(graph_cell);
  graph_cell->SetContext(model_context_);
  auto ret = graph_cell->Load(GetDeviceID());
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Load failed.";
    return ret;
  }
  // save result
  graph_cell_ = graph_cell;
  MS_LOG(INFO) << "Build model success.";
  return kSuccess;
}

Status MsModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
  MS_LOG(INFO) << "Start to resize model";
  auto origin_inputs = GetInputs();
  if (inputs.size() != origin_inputs.size()) {
    MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size "
                  << origin_inputs.size();
    return kMCInvalidInput;
  }
  if (inputs.size() != dims.size()) {
    MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size();
    return kMCInvalidInput;
  }
  auto graph_cell = GenerateGraphCell(dims);
  if (graph_cell == nullptr) {
    MS_LOG(ERROR) << "GenerateGraphCell failed.";
    return kMCFailed;
  }
  MS_LOG(INFO) << "Resize model success.";
  graph_cell_ = std::move(graph_cell);
  return kSuccess;
}

std::vector<MSTensor> MsModel::GetInputs() {
  MS_EXCEPTION_IF_NULL(graph_cell_);
  return graph_cell_->GetInputs();
}

std::vector<MSTensor> MsModel::GetOutputs() {
  MS_EXCEPTION_IF_NULL(graph_cell_);
  return graph_cell_->GetOutputs();
}

uint32_t MsModel::GetDeviceID() const {
  if (model_context_ == nullptr) {
    return 0;
  }
  auto &device_infos = model_context_->MutableDeviceInfo();
  if (device_infos.size() != 1) {
    return 0;
  }
  auto ascend910_info = device_infos[0]->Cast<Ascend910DeviceInfo>();
  if (ascend910_info != nullptr) {
    return ascend910_info->GetDeviceID();
  }
  auto gpu_info = device_infos[0]->Cast<GPUDeviceInfo>();
  if (gpu_info != nullptr) {
    return gpu_info->GetDeviceID();
  }
  return 0;
}

bool MsModel::CheckDeviceSupport(enum DeviceType device_type) {
  const char *cpu_value = ::getenv("SERVING_ENABLE_CPU_DEVICE");
  const char *gpu_value = ::getenv("SERVING_ENABLE_GPU_DEVICE");
  auto enable_cpu = cpu_value && std::string(cpu_value) == "1";
  auto enable_gpu = gpu_value && std::string(gpu_value) == "1";
  if (device_type == kCPU) {
    return enable_cpu;
  } else if (device_type == kGPU) {
    return enable_gpu;
  }
  return !enable_cpu && !enable_gpu;
}

bool MsModel::CheckModelSupport(mindspore::ModelType model_type) {
  static const std::set<ModelType> kSupportedModelMap = {kMindIR};
  auto iter = kSupportedModelMap.find(model_type);
  if (iter == kSupportedModelMap.end()) {
    return false;
  }
  return true;
}
}  // namespace mindspore


================================================
FILE: tests/ut/stub/cxx_api/model/ms/ms_model.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H
#define MINDSPORE_CCSRC_SESSION_SESSION_H
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "cxx_api/model/model_impl.h"

namespace mindspore {
class MsModel : public ModelImpl {
 public:
  MsModel() {}
  ~MsModel() = default;

  Status Build() override;
  Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;

  std::vector<MSTensor> GetInputs() override;
  std::vector<MSTensor> GetOutputs() override;

  bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
  bool CheckModelSupport(enum ModelType model_type) override;

 private:
  std::shared_ptr<GraphCell> GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims);
  uint32_t GetDeviceID() const;

  std::map<std::string, std::shared_ptr<GraphCell>> dynamic_size_graph_map_;
};
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_SESSION_SESSION_H
*/ #include "include/api/serialization.h" #include #include #include "cxx_api/graph/graph_data.h" #include "utils/log_adapter.h" namespace mindspore { static Status RealPath(const std::string &file, std::string *realpath_str) { MS_EXCEPTION_IF_NULL(realpath_str); char real_path_mem[PATH_MAX] = {0}; char *real_path_ret = nullptr; #if defined(_WIN32) || defined(_WIN64) real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX); #else real_path_ret = realpath(common::SafeCStr(file), real_path_mem); #endif if (real_path_ret == nullptr) { return Status(kMEInvalidInput, "File: " + file + " does not exist."); } *realpath_str = real_path_mem; return kSuccess; } static Buffer ReadFile(const std::string &file) { Buffer buffer; if (file.empty()) { MS_LOG(ERROR) << "Pointer file is nullptr"; return buffer; } std::string real_path; auto status = RealPath(file, &real_path); if (status != kSuccess) { MS_LOG(ERROR) << status.GetErrDescription(); return buffer; } std::ifstream ifs(real_path); if (!ifs.good()) { MS_LOG(ERROR) << "File: " << real_path << " does not exist"; return buffer; } if (!ifs.is_open()) { MS_LOG(ERROR) << "File: " << real_path << " open failed"; return buffer; } (void)ifs.seekg(0, std::ios::end); size_t size = static_cast(ifs.tellg()); buffer.ResizeData(size); if (buffer.DataSize() != size) { MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path; ifs.close(); return buffer; } (void)ifs.seekg(0, std::ios::beg); (void)ifs.read(reinterpret_cast(buffer.MutableData()), static_cast(size)); ifs.close(); return buffer; } Key::Key(const char *dec_key, size_t key_len) { len = 0; if (key_len >= max_key_len) { MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len; return; } auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len); if (sec_ret != EOK) { MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret; return; } len = key_len; } Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key, const std::vector &dec_mode) { std::stringstream err_msg; if (graph == nullptr) { err_msg << "Output args graph is nullptr."; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } if (model_type == kMindIR) { FuncGraphPtr anf_graph = nullptr; try { if (dec_key.len > dec_key.max_key_len) { err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } else if (dec_key.len == 0) { if (IsCipherFile(reinterpret_cast(model_data))) { err_msg << "Load model failed. The model_data may be encrypted, please pass in correct key."; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } else { anf_graph = ConvertStreamToFuncGraph(reinterpret_cast(model_data), data_size); } } else { size_t plain_data_size; auto plain_data = mindspore::Decrypt(&plain_data_size, reinterpret_cast(model_data), data_size, dec_key.key, dec_key.len, CharToString(dec_mode)); if (plain_data == nullptr) { err_msg << "Load model failed. Please check the valid of dec_key and dec_mode."; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } anf_graph = ConvertStreamToFuncGraph(reinterpret_cast(plain_data.get()), plain_data_size); } } catch (const std::exception &) { err_msg << "Load model failed. 
Please check the valid of dec_key and dec_mode."; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } *graph = Graph(std::make_shared(anf_graph, kMindIR)); return kSuccess; } else if (model_type == kOM) { *graph = Graph(std::make_shared(Buffer(model_data, data_size), kOM)); return kSuccess; } err_msg << "Unsupported ModelType " << model_type; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } Status Serialization::Load(const std::vector &file, ModelType model_type, Graph *graph) { return Load(file, model_type, graph, Key{}, StringToChar(kDecModeAesGcm)); } Status Serialization::Load(const std::vector &file, ModelType model_type, Graph *graph, const Key &dec_key, const std::vector &dec_mode) { std::stringstream err_msg; if (graph == nullptr) { MS_LOG(ERROR) << "Output args graph is nullptr."; return Status(kMEInvalidInput, "Output args graph is nullptr."); } std::string file_path; auto status = RealPath(CharToString(file), &file_path); if (status != kSuccess) { MS_LOG(ERROR) << status.GetErrDescription(); return status; } if (model_type == kMindIR || model_type == kMindIR_Lite) { FuncGraphPtr anf_graph; if (dec_key.len > dec_key.max_key_len) { err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } else if (dec_key.len == 0 && IsCipherFile(file_path)) { err_msg << "Load model failed. The file may be encrypted, please pass in correct key."; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } else { anf_graph = LoadMindIR(file_path, false, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode)); } if (anf_graph == nullptr) { err_msg << "Load model failed."; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } *graph = Graph(std::make_shared(anf_graph, kMindIR)); return kSuccess; } else if (model_type == kOM) { Buffer data = ReadFile(file_path); if (data.Data() == nullptr) { err_msg << "Read file " << file_path << " failed."; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } *graph = Graph(std::make_shared(data, kOM)); return kSuccess; } err_msg << "Unsupported ModelType " << model_type; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } Status Serialization::Load(const std::vector> &files, ModelType model_type, std::vector *graphs, const Key &dec_key, const std::vector &dec_mode) { std::stringstream err_msg; if (graphs == nullptr) { MS_LOG(ERROR) << "Output args graph is nullptr."; return Status(kMEInvalidInput, "Output args graph is nullptr."); } if (files.size() == 1) { std::vector result(files.size()); auto ret = Load(files[0], model_type, &result[0], dec_key, dec_mode); *graphs = std::move(result); return ret; } std::vector files_path; for (const auto &file : files) { std::string file_path; auto status = RealPath(CharToString(file), &file_path); if (status != kSuccess) { MS_LOG(ERROR) << status.GetErrDescription(); return status; } files_path.emplace_back(std::move(file_path)); } if (model_type == kMindIR) { if (dec_key.len > dec_key.max_key_len) { err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } auto anf_graphs = LoadMindIRs(files_path, false, dec_key.len == 0 ? 
nullptr : dec_key.key, dec_key.len, CharToString(dec_mode)); if (anf_graphs.size() != files_path.size()) { err_msg << "Load model failed, " << files_path.size() << " files got " << anf_graphs.size() << " graphs."; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } std::vector results; for (size_t i = 0; i < anf_graphs.size(); ++i) { if (anf_graphs[i] == nullptr) { if (dec_key.len == 0 && IsCipherFile(files_path[i])) { err_msg << "Load model failed. The file " << files_path[i] << " be encrypted, please pass in correct key."; } else { err_msg << "Load model " << files_path[i] << " failed."; } MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } results.emplace_back(std::make_shared(anf_graphs[i], kMindIR)); } *graphs = std::move(results); return kSuccess; } err_msg << "Unsupported ModelType " << model_type; MS_LOG(ERROR) << err_msg.str(); return Status(kMEInvalidInput, err_msg.str()); } Status Serialization::SetParameters(const std::map &, Model *) { MS_LOG(ERROR) << "Unsupported feature."; return kMEFailed; } Status Serialization::ExportModel(const Model &, ModelType, Buffer *) { MS_LOG(ERROR) << "Unsupported feature."; return kMEFailed; } Status Serialization::ExportModel(const Model &, ModelType, const std::vector &, QuantizationType, bool, const std::vector> &output_tensor_name) { MS_LOG(ERROR) << "Unsupported feature."; return kMEFailed; } } // namespace mindspore ================================================ FILE: tests/ut/stub/cxx_api/status.cc ================================================ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "include/api/status.h" #ifndef ENABLE_ANDROID #include #endif #include #include namespace mindspore { struct Status::Data { enum StatusCode status_code = kSuccess; std::string status_msg; int line_of_code = -1; std::string file_name; std::string err_description; }; Status::Status() : data_(std::make_shared()) {} Status::Status(enum StatusCode status_code, const std::vector &status_msg) : data_(std::make_shared()) { if (data_ == nullptr) { return; } data_->status_msg = CharToString(status_msg); data_->status_code = status_code; } Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::vector &extra) : data_(std::make_shared()) { if (data_ == nullptr) { return; } data_->status_code = code; data_->line_of_code = line_of_code; if (file_name != nullptr) { data_->file_name = file_name; } data_->err_description = CharToString(extra); std::ostringstream ss; #ifndef ENABLE_ANDROID ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". 
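All Load paths converge on a Graph wrapping a GraphData: MindIR input becomes a FuncGraph, OM input stays an opaque Buffer. A hedged sketch of the file-based entry point, using the std::string convenience overload from the public header (file name illustrative, matching the stub's naming convention):

// Sketch, not a repository file: load a MindIR file into a Graph.
mindspore::Graph graph;
auto status = mindspore::Serialization::Load("tensor_add_2_1.mindir", mindspore::kMindIR, &graph);
if (status.IsOk()) {
  // graph.ModelType() == kMindIR; a GraphCell built on it can now be loaded and run.
}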
"; if (!data_->err_description.empty()) { ss << data_->err_description; } ss << "\n"; #endif ss << "Line of code : " << line_of_code << "\n"; if (file_name != nullptr) { ss << "File : " << file_name << "\n"; } data_->status_msg = ss.str(); } enum StatusCode Status::StatusCode() const { if (data_ == nullptr) { return kSuccess; } return data_->status_code; } std::vector Status::ToCString() const { if (data_ == nullptr) { return std::vector(); } return StringToChar(data_->status_msg); } int Status::GetLineOfCode() const { if (data_ == nullptr) { return -1; } return data_->line_of_code; } std::vector Status::GetErrDescriptionChar() const { if (data_ == nullptr) { return std::vector(); } return StringToChar(data_->status_msg); } std::vector Status::CodeAsCString(enum StatusCode c) { static std::map info_map = {{kSuccess, "No error occurs."}, // Core {kCoreFailed, "Common error code."}, // MD {kMDOutOfMemory, "Out of memory"}, {kMDShapeMisMatch, "Shape is incorrect"}, {kMDInterrupted, "Interrupted system call"}, {kMDNoSpace, "No space left on device"}, {kMDPyFuncException, "Exception thrown from PyFunc"}, {kMDDuplicateKey, "Duplicate key"}, {kMDPythonInterpreterFailure, ""}, {kMDTDTPushFailure, "Unexpected error"}, {kMDFileNotExist, "Unexpected error"}, {kMDProfilingError, "Error encountered while profiling"}, {kMDBoundingBoxOutOfBounds, "Unexpected error"}, {kMDBoundingBoxInvalidShape, "Unexpected error"}, {kMDSyntaxError, "Syntax error"}, {kMDTimeOut, "Unexpected error"}, {kMDBuddySpaceFull, "BuddySpace full"}, {kMDNetWorkError, "Network error"}, {kMDNotImplementedYet, "Unexpected error"}, {kMDUnexpectedError, "Unexpected error"}, // ME {kMEFailed, "Common error code."}, {kMEInvalidInput, "Invalid input."}, // MC {kMCFailed, "Common error code."}, {kMCDeviceError, "Device error."}, {kMCInvalidInput, "Invalid input."}, {kMCInvalidArgs, "Invalid arguments."}, // Lite {kLiteError, "Common error code."}, {kLiteNullptr, "NULL pointer returned."}, {kLiteParamInvalid, "Invalid parameter."}, {kLiteNoChange, "No change."}, {kLiteSuccessExit, "No error but exit."}, {kLiteMemoryFailed, "Fail to create memory."}, {kLiteNotSupport, "Fail to support."}, {kLiteThreadPoolError, "Thread pool error."}, {kLiteOutOfTensorRange, "Failed to check range."}, {kLiteInputTensorError, "Failed to check input tensor."}, {kLiteReentrantError, "Exist executor running."}, {kLiteGraphFileError, "Failed to verify graph file."}, {kLiteNotFindOp, "Failed to find operator."}, {kLiteInvalidOpName, "Invalid operator name."}, {kLiteInvalidOpAttr, "Invalid operator attr."}, {kLiteOpExecuteFailure, "Failed to execution operator."}, {kLiteFormatError, "Failed to checking tensor format."}, {kLiteInferError, "Failed to infer shape."}, {kLiteInferInvalid, "Invalid infer shape before runtime."}, {kLiteInputParamInvalid, "Invalid input param by user."}}; auto iter = info_map.find(c); return StringToChar(iter == info_map.end() ? "Unknown error" : iter->second); } std::ostream &operator<<(std::ostream &os, const Status &s) { os << s.ToString(); return os; } std::vector Status::SetErrDescription(const std::vector &err_description) { if (data_ == nullptr) { return std::vector(); } data_->err_description = CharToString(err_description); std::ostringstream ss; #ifndef ENABLE_ANDROID ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(data_->status_code) << ". 
"; if (!data_->err_description.empty()) { ss << data_->err_description; } ss << "\n"; #endif if (data_->line_of_code > 0 && !data_->file_name.empty()) { ss << "Line of code : " << data_->line_of_code << "\n"; ss << "File : " << data_->file_name << "\n"; } data_->status_msg = ss.str(); return StringToChar(data_->status_msg); } bool Status::operator==(const Status &other) const { if (data_ == nullptr && other.data_ == nullptr) { return true; } if (data_ == nullptr || other.data_ == nullptr) { return false; } return data_->status_code == other.data_->status_code; } bool Status::operator==(enum StatusCode other_code) const { return StatusCode() == other_code; } bool Status::operator!=(const Status &other) const { return !operator==(other); } bool Status::operator!=(enum StatusCode other_code) const { return !operator==(other_code); } Status::operator bool() const { return (StatusCode() == kSuccess); } Status::operator int() const { return static_cast(StatusCode()); } Status Status::OK() { return StatusCode::kSuccess; } bool Status::IsOk() const { return (StatusCode() == StatusCode::kSuccess); } bool Status::IsError() const { return !IsOk(); } } // namespace mindspore ================================================ FILE: tests/ut/stub/cxx_api/types.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "include/api/types.h" #include #include #include "securec/include/securec.h" #include "utils/utils.h" namespace mindspore { class Buffer::Impl { public: Impl() : data_() {} ~Impl() = default; Impl(const void *data, size_t data_len) { if (data != nullptr) { (void)SetData(data, data_len); } else { ResizeData(data_len); } } const void *Data() const { return data_.data(); } void *MutableData() { return data_.data(); } size_t DataSize() const { return data_.size(); } bool ResizeData(size_t data_len) { data_.resize(data_len); return true; } bool SetData(const void *data, size_t data_len) { ResizeData(data_len); if (DataSize() != data_len) { MS_LOG(ERROR) << "Set data failed, tensor current data size " << DataSize() << " not match data len " << data_len; return false; } if (data == nullptr) { return data_len == 0; } if (MutableData() == nullptr) { MS_LOG(ERROR) << "Set data failed, data len " << data_len; return false; } auto ret = memcpy_s(MutableData(), DataSize(), data, data_len); if (ret != 0) { MS_LOG(ERROR) << "Set data memcpy_s failed, ret = " << ret; return false; } return true; } protected: std::vector data_; }; class TensorDefaultImpl : public MSTensor::Impl { public: TensorDefaultImpl() : buffer_(), name_(), type_(DataType::kTypeUnknown), shape_() {} ~TensorDefaultImpl() override = default; TensorDefaultImpl(const std::string &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) : buffer_(data, data_len), name_(name), type_(type), shape_(shape) {} const std::string &Name() const override { return name_; } enum DataType DataType() const override { return type_; } const std::vector &Shape() const override { return shape_; } std::shared_ptr Data() const override { return std::shared_ptr(buffer_.Data(), [](const void *) {}); } void *MutableData() override { return buffer_.MutableData(); } size_t DataSize() const override { return buffer_.DataSize(); } bool IsDevice() const override { return false; } std::shared_ptr Clone() const override { return std::make_shared(name_, type_, shape_, buffer_.Data(), buffer_.DataSize()); } private: Buffer buffer_; std::string name_; enum DataType type_; std::vector shape_; }; class TensorReferenceImpl : public MSTensor::Impl { public: TensorReferenceImpl() : data_(nullptr), data_size_(0), name_(), type_(DataType::kTypeUnknown), shape_(), is_device_(false) {} ~TensorReferenceImpl() override = default; TensorReferenceImpl(const std::string &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len, bool is_device) : data_(data), data_size_(data_len), name_(name), type_(type), shape_(shape), is_device_(is_device) {} const std::string &Name() const override { return name_; } enum DataType DataType() const override { return type_; } const std::vector &Shape() const override { return shape_; } std::shared_ptr Data() const override { return std::shared_ptr(data_, [](const void *) {}); } void *MutableData() override { return const_cast(data_); } size_t DataSize() const override { return data_size_; } bool IsDevice() const override { return is_device_; } std::shared_ptr Clone() const override { return std::make_shared(name_, type_, shape_, data_, data_size_, is_device_); } protected: const void *data_; size_t data_size_; std::string name_; enum DataType type_; std::vector shape_; bool is_device_; }; MSTensor *MSTensor::CreateTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept { std::string name_str = CharToString(name); try 
{ std::shared_ptr impl = std::make_shared(name_str, type, shape, data, data_len); MSTensor *ret = new MSTensor(impl); return ret; } catch (const std::bad_alloc &) { MS_LOG(ERROR) << "Malloc memory failed."; return nullptr; } catch (...) { MS_LOG(ERROR) << "Unknown error occurred."; return nullptr; } } MSTensor *MSTensor::CreateRefTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len, bool) noexcept { std::string name_str = CharToString(name); try { std::shared_ptr impl = std::make_shared(name_str, type, shape, data, data_len, false); MSTensor *ret = new MSTensor(impl); return ret; } catch (const std::bad_alloc &) { MS_LOG(ERROR) << "Malloc memory failed."; return nullptr; } catch (...) { MS_LOG(ERROR) << "Unknown error occurred."; return nullptr; } } MSTensor MSTensor::CreateDeviceTensor(const std::vector &name, enum DataType type, const std::vector &shape, void *data, size_t data_len) noexcept { std::string name_str = CharToString(name); try { std::shared_ptr impl = std::make_shared(name_str, type, shape, data, data_len, true); return MSTensor(impl); } catch (const std::bad_alloc &) { MS_LOG(ERROR) << "Malloc memory failed."; return MSTensor(nullptr); } catch (...) { MS_LOG(ERROR) << "Unknown error occurred."; return MSTensor(nullptr); } } MSTensor *MSTensor::CharStringsToTensor(const std::vector &name, const std::vector> &str) { // num(4 bytes) + offset1(4 bytes) + offset2(4 bytes) + ... + data1(str1.len) + data2(str2.len) + ... // str1.len() = offset2 - offset1 // data1.begin() = start + offset1 size_t mem_size = 0; mem_size += sizeof(int32_t); // for num for (const auto &s : str) { mem_size += sizeof(int32_t); // for offset mem_size += s.size(); // for data } auto tensor = CreateTensor(name, DataType::kObjectTypeString, {static_cast(mem_size)}, nullptr, mem_size); if (tensor == nullptr) { MS_LOG(ERROR) << "Create tensor failed."; return nullptr; } int32_t *data = reinterpret_cast(tensor->MutableData()); if (data == nullptr) { MS_LOG(ERROR) << "Create tensor failed."; DestroyTensorPtr(tensor); return nullptr; } uint8_t *cur_data = reinterpret_cast(data + 1 + str.size()); *reinterpret_cast(data) = str.size(); for (size_t i = 0; i < str.size(); ++i) { int32_t offset = (cur_data - reinterpret_cast(data)); data[i + 1] = offset; if (str[i].empty()) { continue; } auto ret = memcpy_s(reinterpret_cast(cur_data), str[i].size(), str[i].data(), str[i].size()); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s failed, ret = " << ret; DestroyTensorPtr(tensor); return nullptr; } cur_data += str[i].size(); } return tensor; } std::vector> MSTensor::TensorToStringChars(const MSTensor &tensor) { if (tensor == nullptr || tensor.DataType() != DataType::kObjectTypeString || tensor.DataSize() < 4) { MS_LOG(ERROR) << "Invalid tensor."; return {}; } std::vector> strings; auto host_data = tensor.Data(); const int32_t *data = reinterpret_cast(host_data.get()); int32_t str_num = data[0]; if (str_num == 0) { return {}; } if (str_num < 0) { MS_LOG(ERROR) << "str num " << str_num << " cannot be negative."; return {}; } if (tensor.DataSize() < (str_num + 1) * sizeof(int32_t)) { MS_LOG(ERROR) << "Invalid tensor data size " << tensor.DataSize() << ", need " << IntToSize(str_num + 1) * sizeof(int32_t) << " at least for " << str_num << " strings."; return {}; } for (size_t i = 0; i < static_cast(str_num); ++i) { strings.push_back({}); auto &str = strings[i]; int32_t str_len; int32_t offset = data[i + 1]; if (i + 1 != static_cast(str_num)) { str_len = data[i + 1 + 1] - 
offset; } else { str_len = tensor.DataSize() - offset; } if (str_len == 0) { continue; } if (str_len < 0) { MS_LOG(ERROR) << "str " << i << " len " << str_len << " cannot be negative."; return {}; } str.resize(str_len); const uint8_t *cur_data = reinterpret_cast(data) + offset; auto ret = memcpy_s(reinterpret_cast(str.data()), str.size(), cur_data, str_len); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s failed, ret = " << ret; return {}; } } return strings; } void MSTensor::DestroyTensorPtr(MSTensor *tensor) noexcept { if (tensor != nullptr) { delete tensor; } } MSTensor::MSTensor() : impl_(std::make_shared()) {} MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {} MSTensor::MSTensor(const std::shared_ptr &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); } MSTensor::MSTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) : impl_(std::make_shared(CharToString(name), type, shape, data, data_len)) {} MSTensor::~MSTensor() = default; bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; } bool MSTensor::operator!=(std::nullptr_t) const { return impl_ != nullptr; } MSTensor *MSTensor::Clone() const { MS_EXCEPTION_IF_NULL(impl_); try { MSTensor *ret = new MSTensor(); ret->impl_ = impl_->Clone(); return ret; } catch (const std::bad_alloc &) { MS_LOG(ERROR) << "Malloc memory failed."; return nullptr; } catch (...) { MS_LOG(ERROR) << "Unknown error occurred."; return nullptr; } } std::vector MSTensor::CharName() const { MS_EXCEPTION_IF_NULL(impl_); return StringToChar(impl_->Name()); } enum DataType MSTensor::DataType() const { MS_EXCEPTION_IF_NULL(impl_); return impl_->DataType(); } const std::vector &MSTensor::Shape() const { MS_EXCEPTION_IF_NULL(impl_); return impl_->Shape(); } int64_t MSTensor::ElementNum() const { MS_EXCEPTION_IF_NULL(impl_); const auto &shape = impl_->Shape(); if (shape.empty()) { // element number of scalar is 1 return 1; } return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } std::shared_ptr MSTensor::Data() const { MS_EXCEPTION_IF_NULL(impl_); return impl_->Data(); } void *MSTensor::MutableData() { MS_EXCEPTION_IF_NULL(impl_); return impl_->MutableData(); } size_t MSTensor::DataSize() const { MS_EXCEPTION_IF_NULL(impl_); return impl_->DataSize(); } bool MSTensor::IsDevice() const { MS_EXCEPTION_IF_NULL(impl_); return impl_->IsDevice(); } void MSTensor::SetShape(const std::vector &) { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetDataType(enum DataType) { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetTensorName(const std::vector &) { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetAllocator(std::shared_ptr) { MS_LOG_EXCEPTION << "Invalid implement."; } std::shared_ptr MSTensor::allocator() const { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetFormat(mindspore::Format) { MS_LOG_EXCEPTION << "Invalid implement."; } mindspore::Format MSTensor::format() const { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetData(void *, bool) { MS_LOG_EXCEPTION << "Invalid implement."; } std::vector MSTensor::QuantParams() const { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetQuantParams(std::vector) { MS_LOG_EXCEPTION << "Invalid implement."; } Buffer::Buffer() : impl_(std::make_shared()) {} Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared(data, data_len)) {} Buffer::~Buffer() = default; Buffer Buffer::Clone() const { MS_EXCEPTION_IF_NULL(impl_); Buffer ret; ret.impl_ = 
std::make_shared(*impl_); return ret; } const void *Buffer::Data() const { MS_EXCEPTION_IF_NULL(impl_); return impl_->Data(); } void *Buffer::MutableData() { MS_EXCEPTION_IF_NULL(impl_); return impl_->MutableData(); } size_t Buffer::DataSize() const { MS_EXCEPTION_IF_NULL(impl_); return impl_->DataSize(); } bool Buffer::ResizeData(size_t data_len) { MS_EXCEPTION_IF_NULL(impl_); return impl_->ResizeData(data_len); } bool Buffer::SetData(const void *data, size_t data_len) { MS_EXCEPTION_IF_NULL(impl_); return impl_->SetData(data, data_len); } std::vector CharVersion() { return {}; } } // namespace mindspore ================================================ FILE: tests/ut/stub/graph_impl_stub.cc ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "stub/graph_impl_stub.h" namespace mindspore { void GraphImplStubAdd::Init(const std::vector &add_shape) { auto element_cnt = [add_shape]() -> size_t { size_t element_num = 1; for (auto dim : add_shape) { if (dim <= 0) { return 0; } element_num *= dim; } return element_num; }; auto ele_size = element_cnt() * sizeof(float); inputs_.clear(); for (size_t i = 0; i < input_count; i++) { MSTensor tensor_x = MSTensor("x" + std::to_string(1), mindspore::DataType::kNumberTypeFloat32, add_shape, nullptr, ele_size); inputs_.push_back(tensor_x); } outputs_.clear(); for (size_t i = 0; i < output_count; i++) { MSTensor tensor_y = MSTensor("x" + std::to_string(1), mindspore::DataType::kNumberTypeFloat32, add_shape, nullptr, ele_size); outputs_.push_back(tensor_y); } } // y=x1+x2+x3+x4, y=x1-x2-x3-x4 // y2=y1+1 Status GraphImplStubAdd::Run(const std::vector &inputs, std::vector *outputs) { auto file_name = graph_->graph_data_->GetFuncGraph()->file_name_; MS_LOG_INFO << "exec model file ------------------- " << file_name; if (inputs.size() != inputs_.size()) { return mindspore::kCoreFailed; } for (size_t i = 0; i < inputs.size(); i++) { if (inputs[i].DataSize() != inputs_[i].DataSize()) { return mindspore::kCoreFailed; } if (inputs_[i].DataSize() != 0 && inputs[i].Data() == nullptr) { return mindspore::kCoreFailed; } } auto item_count = outputs_[0].DataSize() / sizeof(float); auto get_output_tensor = [this](size_t index) -> MSTensor { MSTensor *output_ptr = outputs_[index].Clone(); MSTensor output = *output_ptr; mindspore::MSTensor::DestroyTensorPtr(output_ptr); return output; }; auto output = get_output_tensor(0); auto y = reinterpret_cast(output.MutableData()); auto x0 = reinterpret_cast(inputs[0].Data().get()); for (size_t i = 0; i < item_count; i++) { y[i] = x0[i]; } for (size_t k = 1; k < input_count; k++) { auto xk = reinterpret_cast(inputs[k].Data().get()); for (size_t i = 0; i < item_count; i++) { if (sub_) { y[i] = y[i] - xk[i]; } else { y[i] = y[i] + xk[i]; } } } outputs->push_back(output); for (size_t k = 1; k < output_count; k++) { auto output_k = get_output_tensor(k); auto yk = reinterpret_cast(output_k.MutableData()); for (size_t i = 0; i < item_count; i++) { yk[i] 
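The packed string layout documented in CharStringsToTensor is easiest to see with concrete numbers: for {"ab", "c"}, mem_size = 4 (num) + 2*4 (offsets) + 3 (payload) = 15 bytes, the int32 header is {2, 12, 14}, and the payload "ab" sits at offset 12 with "c" at offset 14 (the last string's length is data size minus its offset: 15 - 14 = 1). A hedged round-trip sketch using the std::string convenience wrappers from the public header, which funnel into the Char functions above:

// Sketch, not a repository file: round-trip strings through the packed layout.
auto *tensor = mindspore::MSTensor::StringsToTensor("strs", {"ab", "c"});
auto decoded = mindspore::MSTensor::TensorToStrings(*tensor);  // yields {"ab", "c"}
mindspore::MSTensor::DestroyTensorPtr(tensor);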
================================================
FILE: tests/ut/stub/graph_impl_stub.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "stub/graph_impl_stub.h"

namespace mindspore {
void GraphImplStubAdd::Init(const std::vector<int64_t> &add_shape) {
  auto element_cnt = [add_shape]() -> size_t {
    size_t element_num = 1;
    for (auto dim : add_shape) {
      if (dim <= 0) {
        return 0;
      }
      element_num *= dim;
    }
    return element_num;
  };
  auto ele_size = element_cnt() * sizeof(float);
  inputs_.clear();
  for (size_t i = 0; i < input_count; i++) {
    MSTensor tensor_x =
      MSTensor("x" + std::to_string(1), mindspore::DataType::kNumberTypeFloat32, add_shape, nullptr, ele_size);
    inputs_.push_back(tensor_x);
  }
  outputs_.clear();
  for (size_t i = 0; i < output_count; i++) {
    MSTensor tensor_y =
      MSTensor("x" + std::to_string(1), mindspore::DataType::kNumberTypeFloat32, add_shape, nullptr, ele_size);
    outputs_.push_back(tensor_y);
  }
}

// y=x1+x2+x3+x4, y=x1-x2-x3-x4
// y2=y1+1
Status GraphImplStubAdd::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  auto file_name = graph_->graph_data_->GetFuncGraph()->file_name_;
  MS_LOG_INFO << "exec model file ------------------- " << file_name;
  if (inputs.size() != inputs_.size()) {
    return mindspore::kCoreFailed;
  }
  for (size_t i = 0; i < inputs.size(); i++) {
    if (inputs[i].DataSize() != inputs_[i].DataSize()) {
      return mindspore::kCoreFailed;
    }
    if (inputs_[i].DataSize() != 0 && inputs[i].Data() == nullptr) {
      return mindspore::kCoreFailed;
    }
  }
  auto item_count = outputs_[0].DataSize() / sizeof(float);
  auto get_output_tensor = [this](size_t index) -> MSTensor {
    MSTensor *output_ptr = outputs_[index].Clone();
    MSTensor output = *output_ptr;
    mindspore::MSTensor::DestroyTensorPtr(output_ptr);
    return output;
  };
  auto output = get_output_tensor(0);
  auto y = reinterpret_cast<float *>(output.MutableData());
  auto x0 = reinterpret_cast<const float *>(inputs[0].Data().get());
  for (size_t i = 0; i < item_count; i++) {
    y[i] = x0[i];
  }
  for (size_t k = 1; k < input_count; k++) {
    auto xk = reinterpret_cast<const float *>(inputs[k].Data().get());
    for (size_t i = 0; i < item_count; i++) {
      if (sub_) {
        y[i] = y[i] - xk[i];
      } else {
        y[i] = y[i] + xk[i];
      }
    }
  }
  outputs->push_back(output);
  for (size_t k = 1; k < output_count; k++) {
    auto output_k = get_output_tensor(k);
    auto yk = reinterpret_cast<float *>(output_k.MutableData());
    for (size_t i = 0; i < item_count; i++) {
      yk[i] = y[i] + k;
    }
    outputs->push_back(output_k);
  }
  return mindspore::kSuccess;
}

Status GraphImplStubAdd::Load(uint32_t device_id) {
  LoadInner();
  auto status = CheckContext();
  if (!status.IsOk()) {
    return status;
  }
  if (input_count == 0 || output_count == 0) {
    MS_LOG_ERROR << "Invalid input count or output count, input count: " << input_count
                 << ", output count: " << output_count;
    return kCoreFailed;
  }
  MS_LOG_INFO << "input count: " << input_count << ", output count: " << output_count;
  Init({2, 2});
  return kSuccess;
}

Status GraphImplStubAdd::CheckContext() {
  auto file_name = graph_->graph_data_->GetFuncGraph()->file_name_;
  bool enable_lite = false;
  if (file_name.find("lite") != std::string::npos) {
    enable_lite = true;
  }
  auto device_info_list = graph_context_->MutableDeviceInfo();
  if (!enable_lite && device_info_list.size() > 1) {
    return kCoreFailed;
  }
  auto beg = file_name.find('@');
  if (beg == std::string::npos) {
    return kSuccess;
  }
  auto device_beg = file_name.find('_', beg);
  std::stringstream ss(file_name.substr(device_beg + 1));
  std::vector<std::string> device_list;
  std::string device_info;
  while (std::getline(ss, device_info, '_')) {
    device_list.push_back(device_info);
  }
  if (device_list.size() != device_info_list.size()) {
    return kCoreFailed;
  }
  std::map<std::string, DeviceType> device_type_map{{"cpu", kCPU}, {"gpu", kGPU}, {"ascend", kAscend}};
  for (size_t i = 0; i < device_list.size(); ++i) {
    if (device_type_map[device_list[i]] != device_info_list[i]->GetDeviceType()) {
      return kCoreFailed;
    }
  }
  return kSuccess;
}

void GraphImplStubAdd::LoadInner() {
  auto file_name = graph_->graph_data_->GetFuncGraph()->file_name_;
  MS_LOG_INFO << "model file ------------------- " << file_name;
  auto beg = file_name.find("tensor_add");  // tensor_add_2_2.mindir or tensor_sub_2_2.mindir
  if (beg == std::string::npos) {
    beg = file_name.find("tensor_sub");
    if (beg == std::string::npos) {
      return;
    }
    sub_ = true;
  }
  beg += std::string("tensor_add").size();
  auto input_beg = file_name.find("_", beg);
  if (input_beg == std::string::npos) {
    return;
  }
  auto output_beg = file_name.find("_", input_beg + 1);
  if (output_beg == std::string::npos) {
    return;
  }
  auto dot_beg = file_name.find(".mindir", output_beg + 1);
  if (dot_beg == std::string::npos) {
    return;
  }
  input_count = std::stoi(file_name.substr(input_beg + 1, output_beg));
  output_count = std::stoi(file_name.substr(output_beg + 1, dot_beg));
}

std::vector<MSTensor> GraphImplStubAdd::GetInputs() { return inputs_; }

std::vector<MSTensor> GraphImplStubAdd::GetOutputs() { return outputs_; }

bool GraphImplStubAdd::CheckDeviceSupport(mindspore::DeviceType device_type) {
  if (device_type == kCPU) {
    const char *value = ::getenv("SERVING_ENABLE_CPU_DEVICE");
    if (value == nullptr || std::string(value) != "1") {
      return false;
    }
  } else if (device_type == kGPU) {
    const char *value = ::getenv("SERVING_ENABLE_GPU_DEVICE");
    if (value == nullptr || std::string(value) != "1") {
      return false;
    }
  } else if (device_type == kAscend || device_type == kAscend310 || device_type == kAscend910) {
    const char *value = ::getenv("SERVING_ENABLE_GPU_DEVICE");
    if (value == nullptr || std::string(value) != "1") {
      return true;
    }
  }
  return true;
}
}  // namespace mindspore
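The stub derives its entire behavior from the model file name: "tensor_add" vs "tensor_sub" picks the op, the two trailing numbers pick the input/output counts, and an optional '@'-suffixed section names the device types CheckContext matches against the context. Summarizing what LoadInner and CheckContext above parse:

// Sketch, not a repository file: the file-name convention the stub interprets.
// "tensor_add_4_2.mindir"  -> add,      input_count = 4, output_count = 2
// "tensor_sub_2_1.mindir"  -> subtract, input_count = 2, output_count = 1
// A "lite" substring permits multiple device infos; after '@', an underscore-
// separated list such as "..._cpu_gpu" must match the context's device types in order.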
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_SERVING_GRAPH_IMPL_STUB_H #define MINDSPORE_SERVING_GRAPH_IMPL_STUB_H #include #include #include #include #include #include #include "include/api/status.h" #include "include/api/graph.h" #include "cxx_api/graph/graph_impl.h" #include "cxx_api/model/model_impl.h" namespace mindspore { class GraphImplStubAdd : public GraphCell::GraphImpl { public: GraphImplStubAdd() = default; ~GraphImplStubAdd() = default; Status Run(const std::vector &inputs, std::vector *outputs) override; Status Load(uint32_t device_id) override; std::vector GetInputs() override; std::vector GetOutputs() override; bool CheckDeviceSupport(mindspore::DeviceType device_type) override; private: std::vector inputs_; std::vector outputs_; uint64_t input_count = 2; uint64_t output_count = 1; bool sub_ = false; // add or sub op void Init(const std::vector &add_shape); void LoadInner(); Status CheckContext(); }; } // namespace mindspore #endif // MINDSPORE_SERVING_GRAPH_IMPL_STUB_H ================================================ FILE: tests/ut/stub/include/api/allocator.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_INCLUDE_API_ALLOCATOR_H #define MINDSPORE_INCLUDE_API_ALLOCATOR_H #include #include "include/api/types.h" namespace mindspore { /// \brief Allocator defined a memory pool for malloc memory and free memory dynamically. class MS_API Allocator { public: /// \brief Destructor of MindSpore Allocator. virtual ~Allocator() = default; /// \brief Method to request memory. /// /// \param[in] size Define the memory size to request. virtual void *Malloc(size_t size) = 0; /// \brief Method to request memory. /// /// \param[in] weight Defines the width of memory to request /// \param[in] height Defines the height of memory to request /// \param[in] type Defines the data type of memory to request virtual void *Malloc(size_t weight, size_t height, DataType type) { return nullptr; } /// \brief Method to free memory. /// /// \param[in] ptr Define the pointer of a certain memory. virtual void Free(void *ptr) = 0; /// \brief Reference count of a certain memory. /// /// \param[in] ptr Define the pointer of a certain memory. /// /// \return Reference count of a certain memory currently. virtual int RefCount(void *ptr) = 0; /// \brief Set reference count of a certain memory. /// /// \param[in] ptr Define the pointer of a certain memory. /// \param[in] ref_count Define the reference count to set. /// /// \return Reference count of a certain memory after setting. 
  virtual int SetRefCount(void *ptr, int ref_count) = 0;

  /// \brief Decrease the reference count of a certain memory.
  ///
  /// \param[in] ptr Define the pointer of a certain memory.
  /// \param[in] ref_count Define the reference count to reduce.
  ///
  /// \return Reference count of a certain memory after decreasing.
  virtual int DecRefCount(void *ptr, int ref_count) = 0;

  /// \brief Increase the reference count of a certain memory.
  ///
  /// \param[in] ptr Define the pointer of a certain memory.
  /// \param[in] ref_count Define the reference count to increase.
  ///
  /// \return Reference count of a certain memory after increasing.
  virtual int IncRefCount(void *ptr, int ref_count) = 0;

  /// \brief Static method to create an allocator.
  ///
  /// \return Smart pointer of an allocator.
  static std::shared_ptr<Allocator> Create();

  /// \brief Prepare a certain memory.
  ///
  /// \param[in] ptr Define the pointer of a certain memory to prepare.
  ///
  /// \return Pointer of ready memory.
  virtual void *Prepare(void *ptr) { return ptr; }

 protected:
  // memory aligned bytes
  size_t aligned_size_ = 32;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_ALLOCATOR_H
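Because every mutator above is pure virtual, even a test double has to supply the full reference-counting surface. A minimal sketch of a concrete subclass, assuming only this stub header is on the include path; the class name and map-based bookkeeping are illustrative, not part of the repository:

#include <cstdlib>
#include <map>
#include "include/api/allocator.h"

namespace {
// Illustrative allocator: plain malloc/free plus a naive ref-count table.
class NaiveAllocator : public mindspore::Allocator {
 public:
  void *Malloc(size_t size) override {
    void *ptr = std::malloc(size);
    if (ptr != nullptr) ref_counts_[ptr] = 1;
    return ptr;
  }
  void Free(void *ptr) override {
    ref_counts_.erase(ptr);
    std::free(ptr);
  }
  int RefCount(void *ptr) override { return ref_counts_.count(ptr) ? ref_counts_[ptr] : 0; }
  int SetRefCount(void *ptr, int ref_count) override { return ref_counts_[ptr] = ref_count; }
  int DecRefCount(void *ptr, int ref_count) override { return ref_counts_[ptr] -= ref_count; }
  int IncRefCount(void *ptr, int ref_count) override { return ref_counts_[ptr] += ref_count; }

 private:
  std::map<void *, int> ref_counts_;
};
}  // namespace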
================================================
FILE: tests/ut/stub/include/api/callback/callback.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
#define MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
class Model;
class ModelImpl;
class CallbackImpl;

using GraphPoint = std::pair<int, float>;

struct TrainCallBackData {
  TrainCallBackData(bool train_mode, int epoch, int step, Model *model)
      : train_mode_(train_mode), epoch_(epoch), step_(step), model_(model) {}

  bool train_mode_;       /**< training mode of LiteSession object */
  unsigned int epoch_;    /**< the current training epoch (starts at 0) */
  unsigned int step_ = 0; /**< the current step within the epoch */
  Model *model_;          /**< pointer to the Model object */
};

enum CallbackRetValue : uint32_t {
  kContinue = 0,
  kStopTraining = 1,
  kExit = 2,
  kUnknownRetValue = 0xFFFFFFFF
};

class TrainCallBack {
 public:
  virtual ~TrainCallBack() = default;

  /// \brief This method is called once before the network executes.
  ///
  /// \param[in] cb_data info about current execution
  virtual void Begin(const TrainCallBackData &cb_data) {}

  /// \brief This method is called once following the network execution.
  ///
  /// \param[in] cb_data info about current execution
  virtual void End(const TrainCallBackData &cb_data) {}

  /// \brief This method is called at the beginning of each epoch.
  ///
  /// \param[in] cb_data info about current execution
  virtual void EpochBegin(const TrainCallBackData &cb_data) {}

  /// \brief This method is called after the run of each epoch.
  ///
  /// \param[in] cb_data info about current execution
  ///
  /// \return indication whether to continue the train loop:
  ///         RET_CONTINUE -- continue training
  ///         RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy)
  ///         RET_EXIT -- exit training (due to error of some sort)
  virtual CallbackRetValue EpochEnd(const TrainCallBackData &cb_data) { return kContinue; }

  /// \brief This method is called at the beginning of each step.
  ///
  /// \param[in] cb_data info about current execution
  virtual void StepBegin(const TrainCallBackData &cb_data) {}

  /// \brief This method is called after each step has run.
  ///
  /// \param[in] cb_data info about current execution
  virtual void StepEnd(const TrainCallBackData &cb_data) {}

 protected:
  friend class Model;
  friend class ModelImpl;
  CallbackImpl *callback_impl_ = nullptr;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CALLBACK_CALLBACK_H
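The callback surface is deliberately all-virtual-with-defaults, so a training script only overrides the hooks it cares about; only EpochEnd can steer the loop via its return value. A minimal sketch, assuming only this header; the EarlyStop class is illustrative:

#include "include/api/callback/callback.h"

namespace {
// Illustrative callback: request a stop once a fixed number of epochs has run.
class EarlyStop : public mindspore::TrainCallBack {
 public:
  explicit EarlyStop(unsigned int max_epochs) : max_epochs_(max_epochs) {}

  mindspore::CallbackRetValue EpochEnd(const mindspore::TrainCallBackData &cb_data) override {
    // Epochs start at 0, so epoch_ + 1 epochs have completed at this point.
    return (cb_data.epoch_ + 1 >= max_epochs_) ? mindspore::kStopTraining : mindspore::kContinue;
  }

 private:
  unsigned int max_epochs_;
};
}  // namespace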
================================================
FILE: tests/ut/stub/include/api/callback/ckpt_saver.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H
#define MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/callback/callback.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
class CkptSaver : public TrainCallBack {
 public:
  inline CkptSaver(int save_every_n, const std::string &filename_prefix);
  virtual ~CkptSaver();

 private:
  CkptSaver(int save_every_n, const std::vector<char> &filename_prefix);
};

CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix)
    : CkptSaver(save_every_n, StringToChar(filename_prefix)) {}
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H

================================================
FILE: tests/ut/stub/include/api/callback/loss_monitor.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_LOSS_MONITOR_H
#define MINDSPORE_INCLUDE_API_CALLBACK_LOSS_MONITOR_H

#include <climits>
#include <utility>
#include <vector>
#include "include/api/callback/callback.h"

namespace mindspore {
class LossMonitor : public TrainCallBack {
 public:
  explicit LossMonitor(int print_every_n_steps = INT_MAX);
  virtual ~LossMonitor();
  const std::vector<GraphPoint> &GetLossPoints();
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CALLBACK_LOSS_MONITOR_H
================================================
FILE: tests/ut/stub/include/api/callback/lr_scheduler.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_LR_SCHEDULER_H
#define MINDSPORE_INCLUDE_API_CALLBACK_LR_SCHEDULER_H

#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "include/api/callback/callback.h"

namespace mindspore {
constexpr int DONT_UPDATE_LR = 0;
constexpr int UPDATE_LR = 1;

using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>;

/// \brief Multiply the LR by a factor of gamma every epoch
int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);

/// \brief Multiply the LR by a factor of gamma every step_size
int StepLRLambda(float *lr, int epoch, void *step_size);

// Note: the struct shares its name with the lambda function above; it is passed
// to the function as the opaque cb_data pointer.
struct StepLRLambda {
  StepLRLambda(int step, float g) : step_size(step), gamma(g) {}

  int step_size;  // period of LR decay
  float gamma;    // LR decay factor
};

class LRScheduler : public TrainCallBack {
 public:
  explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step = 1);
  virtual ~LRScheduler();
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CALLBACK_LR_SCHEDULER_H

================================================
FILE: tests/ut/stub/include/api/callback/time_monitor.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_TIME_MONITOR_H
#define MINDSPORE_INCLUDE_API_CALLBACK_TIME_MONITOR_H

#include <cstddef>
#include <string>
#include <utility>
#include <vector>
#include "include/api/callback/callback.h"

namespace mindspore {
class TimeMonitor : public TrainCallBack {
 public:
  virtual ~TimeMonitor() = default;
  void EpochBegin(const TrainCallBackData &cb_data) override;
  CallbackRetValue EpochEnd(const TrainCallBackData &cb_data) override;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CALLBACK_TIME_MONITOR_H
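Because StepLRLambda the struct doubles as the user data for StepLRLambda the function, wiring an LRScheduler together takes one extra elaborated-type-specifier. A minimal sketch, assuming the real callback implementations (not these stubs) are linked in:

#include "include/api/callback/lr_scheduler.h"

int main() {
  // Decay the LR by a factor of 0.5 every 10 epochs. The 'struct' keyword is
  // required because the function of the same name otherwise hides the type.
  struct mindspore::StepLRLambda step_lr(10, 0.5f);
  // In an expression, the name resolves to the function, which becomes the LR_Lambda.
  mindspore::LRScheduler scheduler(mindspore::StepLRLambda, &step_lr, 1);
  // Hand &scheduler to the training loop together with the other callbacks.
  return 0;
}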
================================================
FILE: tests/ut/stub/include/api/callback/train_accuracy.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CALLBACK_TRAIN_ACCURACY_H
#define MINDSPORE_INCLUDE_API_CALLBACK_TRAIN_ACCURACY_H

#include <climits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/callback/callback.h"
#include "include/api/metrics/accuracy.h"

namespace mindspore {
class TrainAccuracy : public TrainCallBack {
 public:
  explicit TrainAccuracy(int print_every_n = INT_MAX, int accuracy_metrics = METRICS_CLASSIFICATION,
                         const std::vector<int> &input_indexes = {1}, const std::vector<int> &output_indexes = {0});
  virtual ~TrainAccuracy();
  const std::vector<GraphPoint> &GetAccuracyPoints();
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CALLBACK_TRAIN_ACCURACY_H

================================================
FILE: tests/ut/stub/include/api/cell.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CELL_H
#define MINDSPORE_INCLUDE_API_CELL_H

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"

namespace mindspore {
class InputAndOutput;
class Context;
using Input = InputAndOutput;
using Output = InputAndOutput;

class MS_API CellBase {
 public:
  CellBase() = default;
  virtual ~CellBase() = default;
  virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
  virtual std::shared_ptr<CellBase> Clone() const = 0;
  virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; }
  std::vector<Output> operator()(const std::vector<Input> &inputs) const;
};

template <class T>
class MS_API Cell : public CellBase {
 public:
  virtual ~Cell() = default;
  std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
};

class MS_API ParameterCell final : public Cell<ParameterCell> {
 public:
  ParameterCell() = default;
  ~ParameterCell() override = default;
  ParameterCell(const ParameterCell &);
  ParameterCell &operator=(const ParameterCell &);
  ParameterCell(ParameterCell &&);
  ParameterCell &operator=(ParameterCell &&);
  explicit ParameterCell(const MSTensor &);
  ParameterCell &operator=(const MSTensor &);
  explicit ParameterCell(MSTensor &&);
  ParameterCell &operator=(MSTensor &&);
  MSTensor GetTensor() const { return tensor_; }

 private:
  MSTensor tensor_;
};

class MS_API OpCellBase : public CellBase {
 public:
  explicit OpCellBase(const std::string &name) : name_(name) {}
  ~OpCellBase() override = default;
  const std::string &GetOpType() const { return name_; }

 protected:
  std::string name_;
};

template <class T>
class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T> {
 public:
  explicit OpCell(const std::string &name) : OpCellBase(name) {}
  ~OpCell() override = default;
  std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
};

class MS_API GraphCell final : public Cell<GraphCell> {
 public:
  class GraphImpl;

  GraphCell() = default;
  ~GraphCell() override = default;

  explicit GraphCell(const Graph &);
  explicit GraphCell(Graph &&);
  explicit GraphCell(const std::shared_ptr<Graph> &);

  void SetContext(const std::shared_ptr<Context> &context);
  const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
  Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
  std::vector<MSTensor> GetInputs();
  std::vector<MSTensor> GetOutputs();
  Status Load(uint32_t device_id);

 private:
  friend class Model;
  std::shared_ptr<Graph> graph_;
  std::shared_ptr<GraphImpl> executor_;
};

class MS_API InputAndOutput {
 public:
  InputAndOutput();
  ~InputAndOutput() = default;

  // no explicit
  InputAndOutput(const MSTensor &);  // NOLINT(runtime/explicit)
  InputAndOutput(MSTensor &&);       // NOLINT(runtime/explicit)
  InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);

  int32_t GetIndex() const { return index_; }
  void SetIndex(int32_t index) { index_ = index; }

 private:
  std::shared_ptr<CellBase> cell_;
  std::vector<InputAndOutput> prev_;
  int32_t index_;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CELL_H

================================================
FILE: tests/ut/stub/include/api/cfg.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CFG_H
#define MINDSPORE_INCLUDE_API_CFG_H

#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"

namespace mindspore {
class MixPrecisionCfg {
 public:
  MixPrecisionCfg() {
    this->dynamic_loss_scale_ = false;
    this->loss_scale_ = 128.0f;
    this->num_of_not_nan_iter_th_ = 1000;
  }
  ~MixPrecisionCfg() = default;

  bool dynamic_loss_scale_ = false;   /**< Enable/disable dynamic loss scale during mix precision training */
  float loss_scale_;                  /**< Initial loss scale factor */
  uint32_t num_of_not_nan_iter_th_;   /**< Threshold for modifying the loss scale when dynamic loss scale is enabled */
  bool is_raw_mix_precision_ = false; /**< Whether the mix precision model was exported from MindSpore */
};

class TrainCfg {
 public:
  TrainCfg() { this->loss_name_ = "_loss_fn"; }
  ~TrainCfg() = default;

  OptimizationLevel optimization_level_ = kO0;
  std::string loss_name_;             /**< Part of the name that identifies a loss kernel */
  MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
  bool accumulate_gradients_ = false;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CFG_H
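Both config classes are plain data holders, so training setup is just field assignment before the config is handed to Model::Build. A minimal sketch:

#include <memory>
#include "include/api/cfg.h"

int main() {
  auto cfg = std::make_shared<mindspore::TrainCfg>();
  cfg->mix_precision_cfg_.dynamic_loss_scale_ = true;  // adapt the loss scale automatically
  cfg->mix_precision_cfg_.loss_scale_ = 1024.0f;       // starting scale before adaptation
  cfg->accumulate_gradients_ = true;
  // Pass cfg as the train_cfg argument of Model::Build.
  return 0;
}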
================================================
FILE: tests/ut/stub/include/api/context.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H

#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
enum DeviceType {
  kCPU = 0,
  kGPU,
  kKirinNPU,
  kAscend,
  kAscend910,
  kAscend310,
  // add new type here
  kInvalidDeviceType = 100,
};

class Allocator;
class Delegate;
class DeviceInfoContext;

/// \brief Context is used to store environment variables during execution.
class MS_API Context {
 public:
  struct Data;

  Context();
  ~Context() = default;

  /// \brief Set the number of threads at runtime. Only valid for Lite.
  ///
  /// \param[in] thread_num the number of threads at runtime.
  void SetThreadNum(int32_t thread_num);

  /// \brief Get the current thread number setting. Only valid for Lite.
  ///
  /// \return The current thread number setting.
  int32_t GetThreadNum() const;

  /// \brief Set the thread affinity to CPU cores. Only valid for Lite.
  ///
  /// \param[in] mode 0: no affinities, 1: big cores first, 2: little cores first.
  void SetThreadAffinity(int mode);

  /// \brief Get the thread affinity of CPU cores. Only valid for Lite.
  ///
  /// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first.
  int GetThreadAffinityMode() const;

  /// \brief Set the thread lists to CPU cores. Only valid for Lite.
  ///
  /// \note If core_list and mode are both set by SetThreadAffinity, core_list takes effect and mode does not.
  ///
  /// \param[in] core_list a vector of thread core lists.
  void SetThreadAffinity(const std::vector<int> &core_list);

  /// \brief Get the thread lists of CPU cores. Only valid for Lite.
  ///
  /// \return core_list: a vector of thread core lists.
  std::vector<int32_t> GetThreadAffinityCoreList() const;

  /// \brief Set whether to perform model inference or training in parallel. Only valid for Lite.
  ///
  /// \param[in] is_parallel true: parallel; false: not in parallel.
  void SetEnableParallel(bool is_parallel);

  /// \brief Get whether model inference or training is performed in parallel. Only valid for Lite.
  ///
  /// \return Bool value that indicates whether in parallel.
  bool GetEnableParallel() const;

  /// \brief Set Delegate to access a third-party AI framework. Only valid for Lite.
  ///
  /// \param[in] delegate Pointer to the custom delegate.
  void SetDelegate(const std::shared_ptr<Delegate> &delegate);

  /// \brief Get the delegate of the third-party AI framework. Only valid for Lite.
  ///
  /// \return Pointer to the custom delegate.
  std::shared_ptr<Delegate> GetDelegate() const;

  /// \brief Set a quant model to run as a float model in multi device.
  ///
  /// \param[in] float_mode true: run as float model; false: do not run as float model.
  void SetMultiModalHW(bool float_mode);

  /// \brief Get the mode of the model run.
  ///
  /// \return Bool value that indicates whether the model runs as a float model.
  bool GetMultiModalHW() const;

  /// \brief Get a mutable reference of the DeviceInfoContext vector in this context. Only MindSpore Lite supports
  /// heterogeneous scenarios with multiple members in the vector.
  ///
  /// \return Mutable reference of the DeviceInfoContext vector in this context.
  std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();

 private:
  std::shared_ptr<Data> data_;
};

/// \brief DeviceInfoContext defines different device contexts.
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
 public:
  struct Data;

  DeviceInfoContext();
  virtual ~DeviceInfoContext() = default;

  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  virtual enum DeviceType GetDeviceType() const = 0;

  /// \brief A function similar to RTTI is provided when the -fno-rtti compilation option is turned on. It converts
  /// this DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
  ///
  /// \param T Type
  /// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
  template <class T>
  std::shared_ptr<T> Cast() {
    static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
    if (GetDeviceType() != T().GetDeviceType()) {
      return nullptr;
    }
    return std::static_pointer_cast<T>(shared_from_this());
  }

  /// \brief Obtain the provider's name.
  ///
  /// \return The provider's name.
  inline std::string GetProvider() const;

  /// \brief Set the provider's name.
  ///
  /// \param[in] provider define the provider's name.
  inline void SetProvider(const std::string &provider);

  /// \brief Obtain the provider's device type.
  ///
  /// \return The provider's device type.
  inline std::string GetProviderDevice() const;

  /// \brief Set the provider's device type.
  ///
  /// \param[in] device define the provider's device type, e.g. CPU.
  inline void SetProviderDevice(const std::string &device);

  /// \brief Set the memory allocator.
  ///
  /// \param[in] allocator define the memory allocator, which can be defined by the user.
  void SetAllocator(const std::shared_ptr<Allocator> &allocator);

  /// \brief Obtain the memory allocator.
  ///
  /// \return The memory allocator.
  std::shared_ptr<Allocator> GetAllocator() const;

 protected:
  std::vector<char> GetProviderChar() const;
  void SetProvider(const std::vector<char> &provider);
  std::vector<char> GetProviderDeviceChar() const;
  void SetProviderDevice(const std::vector<char> &device);

  std::shared_ptr<Data> data_;
};

std::string DeviceInfoContext::GetProvider() const { return CharToString(GetProviderChar()); }
void DeviceInfoContext::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }
std::string DeviceInfoContext::GetProviderDevice() const { return CharToString(GetProviderDeviceChar()); }
void DeviceInfoContext::SetProviderDevice(const std::string &device) { SetProviderDevice(StringToChar(device)); }

/// \brief Derived from DeviceInfoContext, the configuration of the model running on the CPU. This option is only
/// valid for MindSpore Lite.
class MS_API CPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };

  /// \brief Enable or disable float16 inference.
  ///
  /// \param[in] is_fp16 Enable float16 inference or not.
  void SetEnableFP16(bool is_fp16);

  /// \brief Query whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;
};

/// \brief Derived from DeviceInfoContext, the configuration of the model running on the NPU. This option is only
/// valid for MindSpore Lite.
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };

  /// \brief Set the NPU frequency.
  ///
  /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
  /// performance), default as 3.
  void SetFrequency(int frequency);

  /// \brief Get the NPU frequency.
  ///
  /// \return NPU frequency.
  int GetFrequency() const;
};

/// \brief Derived from DeviceInfoContext, the configuration of the model running on the GPU.
class MS_API GPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };

  /// \brief Set device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;

  /// \brief Get the distribution rank id.
  ///
  /// \return The rank id.
  int GetRankID() const;

  /// \brief Get the distribution group size.
  ///
  /// \return The group size.
  int GetGroupSize() const;

  /// \brief Set the precision mode.
  ///
  /// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default.
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// \brief Get the precision mode.
  ///
  /// \return The precision mode.
  inline std::string GetPrecisionMode() const;

  /// \brief Enable or disable float16 inference.
  ///
  /// \param[in] is_fp16 Enable float16 inference or not.
  void SetEnableFP16(bool is_fp16);

  /// \brief Query whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;

  /// \brief Enable or disable sharing OpenCL memory with OpenGL.
  ///
  /// \param[in] is_enable_gl_texture Enable sharing OpenCL memory with OpenGL or not.
  void SetEnableGLTexture(bool is_enable_gl_texture);

  /// \brief Query whether sharing memory with OpenGL is enabled.
  ///
  /// \return Whether sharing memory with OpenGL is enabled.
  bool GetEnableGLTexture() const;

  /// \brief Set the current OpenGL context.
  ///
  /// \param[in] gl_context Current OpenGL context.
  void SetGLContext(void *gl_context);

  /// \brief Get the current OpenGL context.
  ///
  /// \return The OpenCL context used by OpenGL.
  void *GetGLContext() const;

  /// \brief Set the current OpenGL display.
  ///
  /// \param[in] gl_display Current OpenGL display.
  void SetGLDisplay(void *gl_display);

  /// \brief Get the current OpenGL display.
  ///
  /// \return The OpenCL display used by OpenGL.
  void *GetGLDisplay() const;

 private:
  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;
};

void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

/// \brief Derived from DeviceInfoContext, the configuration of the model running on the Ascend. This option is
/// invalid for MindSpore Lite.
class MS_API AscendDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };

  /// \brief Set device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;

  /// \brief Set AIPP configuration file path.
  ///
  /// \param[in] cfg_path AIPP configuration file path.
  inline void SetInsertOpConfigPath(const std::string &cfg_path);

  /// \brief Get AIPP configuration file path.
  ///
  /// \return AIPP configuration file path.
  inline std::string GetInsertOpConfigPath() const;

  /// \brief Set format of model inputs.
  ///
  /// \param[in] format Optional "NCHW", "NHWC", etc.
  inline void SetInputFormat(const std::string &format);

  /// \brief Get format of model inputs.
  ///
  /// \return The format of model inputs.
  inline std::string GetInputFormat() const;

  /// \brief Set shape of model inputs.
  ///
  /// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
  inline void SetInputShape(const std::string &shape);

  /// \brief Get shape of model inputs.
  ///
  /// \return The shape of model inputs.
  inline std::string GetInputShape() const;

  /// \brief Set shape of model inputs.
  ///
  /// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape is 1,2,3,4 and the second
  /// input shape is 4,3,2,1.
  void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);

  /// \brief Get shape of model inputs.
  ///
  /// \return The shape of model inputs.
  std::map<int, std::vector<int>> GetInputShapeMap() const;

  void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
  inline std::string GetDynamicBatchSize() const;

  /// \brief Set the dynamic image size of model inputs.
  ///
  /// \param[in] dynamic_image_size Image sizes as "h,w" pairs, e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
  inline void SetDynamicImageSize(const std::string &dynamic_image_size);

  /// \brief Get dynamic image size of model inputs.
  ///
  /// \return The image size of model inputs.
  inline std::string GetDynamicImageSize() const;

  /// \brief Set type of model outputs.
  ///
  /// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
  void SetOutputType(enum DataType output_type);

  /// \brief Get type of model outputs.
  ///
  /// \return The set type of model outputs.
  enum DataType GetOutputType() const;

  /// \brief Set precision mode of model.
  ///
  /// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
  /// "allow_mix_precision"; "force_fp16" is set as default.
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// \brief Get precision mode of model.
  ///
  /// \return The set precision mode.
  inline std::string GetPrecisionMode() const;

  /// \brief Set op select implementation mode.
  ///
  /// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision"; "high_performance" is set as
  /// default.
  inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);

  /// \brief Get op select implementation mode.
  ///
  /// \return The set op select implementation mode.
  inline std::string GetOpSelectImplMode() const;

  inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
  inline std::string GetFusionSwitchConfigPath() const;

  // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
  inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
  inline std::string GetBufferOptimizeMode() const;

 private:
  void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetInsertOpConfigPathChar() const;
  void SetInputFormat(const std::vector<char> &format);
  std::vector<char> GetInputFormatChar() const;
  void SetInputShape(const std::vector<char> &shape);
  std::vector<char> GetInputShapeChar() const;
  std::vector<char> GetDynamicBatchSizeChar() const;
  void SetDynamicImageSize(const std::vector<char> &dynamic_image_size);
  std::vector<char> GetDynamicImageSizeChar() const;
  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;
  void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  std::vector<char> GetOpSelectImplModeChar() const;
  void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetFusionSwitchConfigPathChar() const;
  void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  std::vector<char> GetBufferOptimizeModeChar() const;
};

using Ascend310DeviceInfo = AscendDeviceInfo;
using Ascend910DeviceInfo = AscendDeviceInfo;
using Ascend710DeviceInfo = AscendDeviceInfo;

void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }

void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }

void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }

std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }

void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
  SetDynamicImageSize(StringToChar(dynamic_image_size));
}
std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }

void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
  return CharToString(GetFusionSwitchConfigPathChar());
}

void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
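Every device-specific option above is attached through MutableDeviceInfo(), so standing up a context is append-only. A minimal sketch, assuming only this header:

#include <memory>
#include "include/api/context.h"

int main() {
  auto context = std::make_shared<mindspore::Context>();

  auto ascend = std::make_shared<mindspore::AscendDeviceInfo>();
  ascend->SetDeviceID(0);
  ascend->SetPrecisionMode("allow_fp32_to_fp16");

  // In heterogeneous (Lite) setups the list order matters: earlier entries are preferred.
  context->MutableDeviceInfo().push_back(ascend);
  // Pass context as the model_context argument of Model::Build.
  return 0;
}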
================================================
FILE: tests/ut/stub/include/api/data_type.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_
#define MINDSPORE_INCLUDE_API_DATA_TYPE_H_

#include <cstdint>

namespace mindspore {
enum class DataType : int {
  kTypeUnknown = 0,
  kObjectTypeString = 12,
  kObjectTypeList = 13,
  kObjectTypeTuple = 14,
  kObjectTypeTensorType = 17,
  kNumberTypeBegin = 29,
  kNumberTypeBool = 30,
  kNumberTypeInt8 = 32,
  kNumberTypeInt16 = 33,
  kNumberTypeInt32 = 34,
  kNumberTypeInt64 = 35,
  kNumberTypeUInt8 = 37,
  kNumberTypeUInt16 = 38,
  kNumberTypeUInt32 = 39,
  kNumberTypeUInt64 = 40,
  kNumberTypeFloat16 = 42,
  kNumberTypeFloat32 = 43,
  kNumberTypeFloat64 = 44,
  kNumberTypeEnd = 46,
  // add new enum here
  kInvalidType = INT32_MAX,
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_DATA_TYPE_H_

================================================
FILE: tests/ut/stub/include/api/delegate.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_DELEGATE_H
#define MINDSPORE_INCLUDE_API_DELEGATE_H

#include <map>
#include <memory>
#include <vector>
#include "schema/model_generated.h"
#include "include/api/kernel.h"
#include "include/api/status.h"

namespace mindspore {
typedef enum {
  SCHEMA_INVALID = -1, /**< invalid version */
  SCHEMA_CUR,          /**< current version for ms model defined in model.fbs */
  SCHEMA_V0,           /**< previous version for ms model defined in model_v0.fbs */
} SchemaVersion;

using KernelIter = std::vector<kernel::Kernel *>::iterator;

template <typename T>
class MS_API DelegateModel {
 public:
  /// \brief Constructor of MindSpore Lite DelegateModel.
  DelegateModel(std::vector<kernel::Kernel *> *kernels, const std::vector<mindspore::MSTensor> &inputs,
                const std::vector<mindspore::MSTensor> &outputs,
                const std::map<kernel::Kernel *, const T *> &primitives, SchemaVersion version)
      : kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives), version_(version) {}

  /// \brief Destructor of MindSpore Lite DelegateModel.
  ~DelegateModel() = default;

  /// \brief Get the Primitive of a kernel::Kernel.
  ///
  /// \param[in] kernel a kernel in the DelegateModel kernels vector.
  ///
  /// \return The Primitive of the kernel.
  const T *GetPrimitive(kernel::Kernel *kernel) const {
    if (primitives_.find(kernel) != primitives_.end()) {
      return primitives_.at(kernel);
    } else {
      return nullptr;
    }
  }

  /// \brief Get the begin iterator of the DelegateModel kernels vector.
  ///
  /// \return The begin iterator of the DelegateModel kernels vector.
  KernelIter BeginKernelIterator() { return kernels_->begin(); }

  /// \brief Get the end iterator of the DelegateModel kernels vector.
  ///
  /// \return The end iterator of the DelegateModel kernels vector.
  KernelIter EndKernelIterator() { return kernels_->end(); }

  /// \brief Replace the continuous kernels supported by the delegate with a delegate graph kernel.
  ///
  /// \param[in] from Define the begin iterator of continuous kernels supported by the delegate.
  /// \param[in] end Define the end iterator of continuous kernels supported by the delegate.
  ///
  /// \return The next iterator after graph_kernel, pointing to the next kernel that has not been visited.
  KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel) {
    size_t insert_index = from - BeginKernelIterator();
    if (insert_index >= kernels_->size()) {
      return BeginKernelIterator();
    }
    kernels_->erase(from, end);
    kernels_->insert(BeginKernelIterator() + insert_index, graph_kernel);
    return BeginKernelIterator() + insert_index + 1;
  }

  /// \brief Get the input tensors of the DelegateModel.
  ///
  /// \return The input tensor vector of the DelegateModel.
  const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }

  /// \brief Get the output tensors of the DelegateModel.
  ///
  /// \return The output tensor vector of the DelegateModel.
  const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }

  /// \brief Get the ms model version.
  ///
  /// \return The schema version for the primitives map.
  SchemaVersion GetVersion() const { return version_; }

 protected:
  std::vector<kernel::Kernel *> *kernels_;
  const std::vector<mindspore::MSTensor> &inputs_;
  const std::vector<mindspore::MSTensor> &outputs_;
  const std::map<kernel::Kernel *, const T *> &primitives_;
  SchemaVersion version_;
};

class MS_API Delegate {
 public:
  /// \brief Constructor of MindSpore Lite Delegate.
  Delegate() = default;

  /// \brief Destructor of MindSpore Lite Delegate.
  virtual ~Delegate() = default;

  /// \brief Init delegate.
  ///
  /// \note Init will be called in Model::Build.
  ///
  /// \return Status. If Status is kLiteNotSupport, the program will return to the MindSpore Lite inner inference.
  virtual Status Init() = 0;

  /// \brief Build the delegate graph for a MindSpore Lite model.
  ///
  /// \note Build will be called in Model::Build.
  ///
  /// \param[in] model Define the delegate model to be built.
  virtual Status Build(DelegateModel<schema::Primitive> *model) = 0;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_DELEGATE_H
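Init() returning kLiteNotSupport is the documented escape hatch back to built-in inference, which makes a pass-through delegate trivial. A minimal sketch, assuming only this header; NullDelegate is illustrative, not part of the repository:

#include "include/api/delegate.h"

namespace {
// Illustrative delegate that claims no kernels, so every graph keeps running
// on the framework's own inference path.
class NullDelegate : public mindspore::Delegate {
 public:
  mindspore::Status Init() override { return mindspore::kSuccess; }

  mindspore::Status Build(mindspore::DelegateModel<mindspore::schema::Primitive> *model) override {
    // Walk the kernels without replacing any of them; a real delegate would
    // collect supported runs and call model->Replace(from, end, graph_kernel).
    for (auto it = model->BeginKernelIterator(); it != model->EndKernelIterator(); ++it) {
      (void)*it;
    }
    return mindspore::kSuccess;
  }
};
}  // namespace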
================================================
FILE: tests/ut/stub/include/api/dual_abi_helper.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
#define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_

#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace mindspore {
inline std::vector<char> StringToChar(const std::string &s) { return std::vector<char>(s.begin(), s.end()); }

inline std::string CharToString(const std::vector<char> &c) { return std::string(c.begin(), c.end()); }

inline std::pair<std::vector<char>, int32_t> PairStringToChar(const std::pair<std::string, int32_t> &s) {
  return std::pair<std::vector<char>, int32_t>(std::vector<char>(s.first.begin(), s.first.end()), s.second);
}

inline std::pair<std::string, int32_t> PairCharToString(const std::pair<std::vector<char>, int32_t> &c) {
  return std::pair<std::string, int32_t>(std::string(c.first.begin(), c.first.end()), c.second);
}

inline std::vector<std::vector<char>> VectorStringToChar(const std::vector<std::string> &s) {
  std::vector<std::vector<char>> ret;
  std::transform(s.begin(), s.end(), std::back_inserter(ret),
                 [](auto str) { return std::vector<char>(str.begin(), str.end()); });
  return ret;
}

inline std::vector<std::string> VectorCharToString(const std::vector<std::vector<char>> &c) {
  std::vector<std::string> ret;
  std::transform(c.begin(), c.end(), std::back_inserter(ret),
                 [](auto ch) { return std::string(ch.begin(), ch.end()); });
  return ret;
}

inline std::set<std::vector<char>> SetStringToChar(const std::set<std::string> &s) {
  std::set<std::vector<char>> ret;
  std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()),
                 [](auto str) { return std::vector<char>(str.begin(), str.end()); });
  return ret;
}

inline std::set<std::string> SetCharToString(const std::set<std::vector<char>> &c) {
  std::set<std::string> ret;
  std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()),
                 [](auto ch) { return std::string(ch.begin(), ch.end()); });
  return ret;
}

inline std::map<std::vector<char>, int32_t> MapStringToChar(const std::map<std::string, int32_t> &s) {
  std::map<std::vector<char>, int32_t> ret;
  std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()), [](auto str) {
    return std::pair<std::vector<char>, int32_t>(std::vector<char>(str.first.begin(), str.first.end()), str.second);
  });
  return ret;
}

inline std::map<std::string, int32_t> MapCharToString(const std::map<std::vector<char>, int32_t> &c) {
  std::map<std::string, int32_t> ret;
  std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()), [](auto ch) {
    return std::pair<std::string, int32_t>(std::string(ch.first.begin(), ch.first.end()), ch.second);
  });
  return ret;
}

inline std::map<std::vector<char>, std::vector<char>> UnorderedMapStringToChar(
  const std::unordered_map<std::string, std::string> &s) {
  std::map<std::vector<char>, std::vector<char>> ret;
  std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()), [](auto str) {
    return std::pair<std::vector<char>, std::vector<char>>(std::vector<char>(str.first.begin(), str.first.end()),
                                                           std::vector<char>(str.second.begin(), str.second.end()));
  });
  return ret;
}

inline std::unordered_map<std::string, std::string> UnorderedMapCharToString(
  const std::map<std::vector<char>, std::vector<char>> &c) {
  std::unordered_map<std::string, std::string> ret;
  std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()), [](auto ch) {
    return std::pair<std::string, std::string>(std::string(ch.first.begin(), ch.first.end()),
                                               std::string(ch.second.begin(), ch.second.end()));
  });
  return ret;
}

inline std::map<std::vector<char>, std::vector<char>> MapStringToVectorChar(
  const std::map<std::string, std::string> &s) {
  std::map<std::vector<char>, std::vector<char>> ret;
  std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()), [](auto str) {
    return std::pair<std::vector<char>, std::vector<char>>(std::vector<char>(str.first.begin(), str.first.end()),
                                                           std::vector<char>(str.second.begin(), str.second.end()));
  });
  return ret;
}

inline std::map<std::string, std::string> MapVectorCharToString(
  const std::map<std::vector<char>, std::vector<char>> &c) {
  std::map<std::string, std::string> ret;
  std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()), [](auto ch) {
    return std::pair<std::string, std::string>(std::string(ch.first.begin(), ch.first.end()),
                                               std::string(ch.second.begin(), ch.second.end()));
  });
  return ret;
}

inline std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ClassIndexStringToChar(
  const std::vector<std::pair<std::string, std::vector<int32_t>>> &s) {
  std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ret;
  std::transform(s.begin(), s.end(),
                 std::back_inserter(ret), [](auto str) {
                   return std::pair<std::vector<char>, std::vector<int32_t>>(
                     std::vector<char>(str.first.begin(), str.first.end()), str.second);
                 });
  return ret;
}

inline std::vector<std::pair<std::string, std::vector<int32_t>>> ClassIndexCharToString(
  const std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> &c) {
  std::vector<std::pair<std::string, std::vector<int32_t>>> ret;
  std::transform(c.begin(), c.end(), std::back_inserter(ret), [](auto ch) {
    return std::pair<std::string, std::vector<int32_t>>(std::string(ch.first.begin(), ch.first.end()), ch.second);
  });
  return ret;
}

inline std::vector<std::pair<std::vector<char>, int64_t>> PairStringInt64ToPairCharInt64(
  const std::vector<std::pair<std::string, int64_t>> &s) {
  std::vector<std::pair<std::vector<char>, int64_t>> ret;
  std::transform(s.begin(), s.end(), std::back_inserter(ret), [](auto str) {
    return std::pair<std::vector<char>, int64_t>(std::vector<char>(str.first.begin(), str.first.end()), str.second);
  });
  return ret;
}

template <class T>
inline std::map<std::vector<char>, T> PadInfoStringToChar(const std::map<std::string, T> &s_pad_info) {
  std::map<std::vector<char>, T> ret;
  std::transform(s_pad_info.begin(), s_pad_info.end(), std::inserter(ret, ret.begin()), [](auto str) {
    return std::pair<std::vector<char>, T>(std::vector<char>(str.first.begin(), str.first.end()), str.second);
  });
  return ret;
}

template <class T>
inline std::map<std::string, T> PadInfoCharToString(const std::map<std::vector<char>, T> &c_pad_info) {
  std::map<std::string, T> ret;
  std::transform(c_pad_info.begin(), c_pad_info.end(), std::inserter(ret, ret.begin()), [](auto ch) {
    return std::pair<std::string, T>(std::string(ch.first.begin(), ch.first.end()), ch.second);
  });
  return ret;
}

template <class T>
inline void TensorMapCharToString(const std::map<std::vector<char>, T> *c, std::unordered_map<std::string, T> *s) {
  if (c == nullptr || s == nullptr) {
    return;
  }
  for (auto ch : *c) {
    auto key = std::string(ch.first.begin(), ch.first.end());
    auto val = ch.second;
    s->insert(std::pair<std::string, T>(key, val));
  }
}
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
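The whole header exists to keep std::string out of the ABI boundary: the public inline API converts strings to std::vector<char> before crossing into the library, so a libstdc++ dual-ABI mismatch between caller and library cannot corrupt string layouts. A minimal round-trip sketch:

#include <cassert>
#include <string>
#include "include/api/dual_abi_helper.h"

int main() {
  // Strings become char vectors before crossing the library boundary and are
  // rebuilt on the other side; the round trip is lossless.
  std::string name = "serving";
  auto wire = mindspore::StringToChar(name);
  assert(mindspore::CharToString(wire) == name);
  return 0;
}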
================================================
FILE: tests/ut/stub/include/api/format.h
================================================
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_FORMAT_H
#define MINDSPORE_INCLUDE_API_FORMAT_H

#if __has_include("include/mindapi/base/format.h")
#include "include/mindapi/base/format.h"
#else
#include "mindapi/base/format.h"
#endif

#endif  // MINDSPORE_INCLUDE_API_FORMAT_H

================================================
FILE: tests/ut/stub/include/api/graph.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_GRAPH_H
#define MINDSPORE_INCLUDE_API_GRAPH_H

#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "include/api/status.h"
#include "include/api/types.h"

namespace mindspore {
class MS_API Graph {
 public:
  class GraphData;

  Graph();
  explicit Graph(const std::shared_ptr<GraphData> &graph_data);
  explicit Graph(std::shared_ptr<GraphData> &&graph_data);
  explicit Graph(std::nullptr_t);
  ~Graph();

  enum ModelType ModelType() const;
  bool operator==(std::nullptr_t) const;
  bool operator!=(std::nullptr_t) const;

 private:
  friend class GraphCell;
  friend class ModelImpl;
  friend class GraphImplStubAdd;
  std::shared_ptr<GraphData> graph_data_;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_GRAPH_H

================================================
FILE: tests/ut/stub/include/api/kernel.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_KERNEL_H
#define MINDSPORE_INCLUDE_API_KERNEL_H

#include <map>
#include <string>
#include <utility>
#include <vector>
#include "schema/model_generated.h"
#include "include/api/types.h"
#include "include/api/context.h"

namespace mindspore::kernel {
/// \brief The Kernel class is used to define a MindSpore Kernel.
class MS_API Kernel {
 public:
  Kernel() = default;

  /// \brief Constructor.
  ///
  /// \param[in] inputs define the input tensors for kernel.
  /// \param[in] outputs define the output tensors for kernel.
  /// \param[in] primitive define the primitive of kernel generated by flatbuffers.
  /// \param[in] ctx define the context for kernel.
  Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
         const schema::Primitive *primitive, const mindspore::Context *ctx)
      : context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) {
    Initialize();
  }

  /// \brief Destructor.
  virtual ~Kernel() = default;

  /// \brief prepare for executing kernel.
  ///
  /// \return result code.
  virtual int Prepare() = 0;

  /// \brief execute the kernel.
  ///
  /// \return result code.
  virtual int Execute() = 0;

  /// \brief resize the kernel input shape; memory needs to refresh.
  ///
  /// \return result code.
  virtual int ReSize() = 0;

  /// \brief set kernel's input tensors.
  ///
  /// \param[in] in_tensors define the input tensors.
  virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors) { this->inputs_ = in_tensors; }

  /// \brief set kernel's input tensor.
  ///
  /// \param[in] in_tensor define the input tensor.
  /// \param[in] index define the index of the input tensor.
  virtual void set_input(mindspore::MSTensor in_tensor, int index) { this->inputs_[index] = in_tensor; }

  /// \brief set kernel's output tensors.
  ///
  /// \param[in] out_tensors define the output tensors.
  virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors) { this->outputs_ = out_tensors; }

  /// \brief set kernel's output tensor.
  ///
  /// \param[in] out_tensor define the output tensor.
  /// \param[in] index define the index of the output tensor.
  virtual void set_output(mindspore::MSTensor out_tensor, int index) { this->outputs_[index] = out_tensor; }

  /// \brief obtain kernel's input tensors.
  ///
  /// \return input tensors.
  virtual const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }

  /// \brief obtain kernel's output tensors.
  ///
  /// \return output tensors.
  virtual const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }

  /// \brief obtain kernel's name.
  ///
  /// \return kernel's name.
  std::string name() const { return this->name_; }

  /// \brief set kernel's name.
  ///
  /// \param[in] name define the kernel's name.
  void set_name(const std::string &name) { this->name_ = name; }

  /// \brief obtain kernel's context.
  ///
  /// \return kernel's context.
  const mindspore::Context *context() const { return this->context_; }

  /// \brief obtain kernel's type.
  ///
  /// \return kernel's type.
  virtual schema::PrimitiveType type() const { return type_; }

  /// \brief obtain kernel's quant type.
  ///
  /// \return kernel's quant type.
  virtual schema::QuantType quant_type() const { return quant_type_; }

  /// \brief obtain the primitive of kernel generated by flatbuffers.
  ///
  /// \return the primitive of kernel generated by flatbuffers.
  const schema::Primitive *primitive() const { return this->primitive_; }

  /// \brief get kernel's attribute.
  ///
  /// \param[in] key define the kernel's attribute key.
  std::string GetAttr(const std::string &key) const {
    auto iter = attrs_.find(key);
    if (iter != attrs_.end()) {
      return iter->second;
    }
    return "";
  }

  /// \brief set kernel's config.
  ///
  /// \param[in] config define the kernel's config.
  void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config) { config_ = config; }

  /// \brief obtain one section of the kernel's config.
  ///
  /// \param[in] section define the config section name.
  std::map<std::string, std::string> GetConfig(const std::string &section) const {
    if (config_ == nullptr) {
      return std::map<std::string, std::string>();
    }
    auto iter = config_->find(section);
    if (iter != config_->end()) {
      return iter->second;
    }
    return std::map<std::string, std::string>();
  }

 protected:
  /// \brief set kernel's attribute.
  ///
  /// \param[in] key define the kernel's attribute key.
  /// \param[in] value define the kernel's attribute value.
  void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; }

  std::string name_;
  const mindspore::Context *context_ = nullptr;
  std::vector<mindspore::MSTensor> inputs_;
  std::vector<mindspore::MSTensor> outputs_;
  schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
  const schema::Primitive *primitive_ = nullptr;
  std::map<std::string, std::string> attrs_;
  const std::map<std::string, std::map<std::string, std::string>> *config_ = nullptr;
  schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;

 private:
  void Initialize();
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_INCLUDE_API_KERNEL_H
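Custom ops plug in by subclassing Kernel and overriding the three execution hooks; everything else (names, attrs, config) comes for free from the base class. A minimal sketch, assuming only this header; the identity kernel and its 0/-1 result codes are illustrative:

#include "include/api/kernel.h"

namespace {
// Illustrative kernel that forwards its single input tensor to its output.
class IdentityKernel : public mindspore::kernel::Kernel {
 public:
  using mindspore::kernel::Kernel::Kernel;  // reuse the base constructors

  int Prepare() override { return 0; }
  int ReSize() override { return 0; }
  int Execute() override {
    if (inputs_.size() != 1 || outputs_.size() != 1) return -1;  // -1 taken as failure here
    outputs_[0] = inputs_[0];  // MSTensor copies share the underlying implementation
    return 0;
  }
};
}  // namespace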
================================================
FILE: tests/ut/stub/include/api/metrics/accuracy.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_METRICS_ACCURACY_H
#define MINDSPORE_INCLUDE_API_METRICS_ACCURACY_H

#include <vector>
#include "include/api/metrics/metrics.h"

namespace mindspore {
constexpr int METRICS_CLASSIFICATION = 0;
constexpr int METRICS_MULTILABEL = 1;

class AccuracyMetrics : public Metrics {
 public:
  explicit AccuracyMetrics(int accuracy_metrics = METRICS_CLASSIFICATION, const std::vector<int> &input_indexes = {1},
                           const std::vector<int> &output_indexes = {0});
  virtual ~AccuracyMetrics();
  void Clear() override;
  float Eval() override;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_METRICS_ACCURACY_H

================================================
FILE: tests/ut/stub/include/api/metrics/metrics.h
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_METRICS_METRICS_H
#define MINDSPORE_INCLUDE_API_METRICS_METRICS_H

#include <vector>
#include "include/api/model.h"

namespace mindspore {
class MetricsImpl;
class ModelImpl;
class MSTensor;

class Metrics {
 public:
  virtual ~Metrics() = default;
  virtual void Clear() {}
  virtual float Eval() { return 0.0; }
  virtual void Update(std::vector<MSTensor *> inputs, std::vector<MSTensor *> outputs) {}

 protected:
  friend class Model;
  friend class ModelImpl;
  MetricsImpl *metrics_impl_;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_METRICS_METRICS_H
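A metric only has to supply Clear/Eval/Update; the framework calls Update once per step via InitMetrics/GetMetrics on the Model. A minimal sketch, assuming only this header; the batch-counting metric is illustrative:

#include <vector>
#include "include/api/metrics/metrics.h"

namespace {
// Illustrative metric that simply counts how many batches it has seen.
class BatchCounter : public mindspore::Metrics {
 public:
  void Clear() override { batches_ = 0; }
  float Eval() override { return static_cast<float>(batches_); }
  void Update(std::vector<mindspore::MSTensor *> inputs, std::vector<mindspore::MSTensor *> outputs) override {
    ++batches_;
  }

 private:
  int batches_ = 0;
};
}  // namespace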
================================================ FILE: tests/ut/stub/include/api/model.h ================================================
/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_INCLUDE_API_MODEL_H #define MINDSPORE_INCLUDE_API_MODEL_H #include <string> #include <vector> #include <map> #include <memory> #include <utility> #include "include/api/status.h" #include "include/api/types.h" #include "include/api/graph.h" #include "include/api/context.h" #include "include/api/callback/callback.h" #include "include/api/cell.h" #include "include/api/cfg.h" #include "include/api/dual_abi_helper.h" namespace mindspore { class ModelImpl; class Metrics; namespace dataset { class Dataset; } // namespace dataset
/// \brief The Model class is used to define a MindSpore model, facilitating computational graph management. class MS_API Model { public: Model(); ~Model(); Model(const Model &) = delete; void operator=(const Model &) = delete;
/// \brief Builds a model. /// /// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed /// from Graph, for example, model.Build(GraphCell(graph), context). /// \param[in] model_context A context used to store options during execution. /// \param[in] train_cfg A config used by training. /// /// \return Status. Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr, const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
/// \brief Builds a transfer learning model where the backbone weights are fixed and the head weights are trainable. /// /// \param[in] backbone The static, non-learnable part of the graph. /// \param[in] head The trainable part of the graph. /// \param[in] context A context used to store options during execution. /// \param[in] train_cfg A config used by training. /// /// \return Status. Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context, const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
/// \brief Resizes the shapes of inputs. /// /// \param[in] inputs A vector that includes all input tensors in order. /// \param[in] dims Defines the new shapes of inputs, should be consistent with inputs. /// /// \return Status. Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
/// \brief Changes the size and/or content of weight tensors. /// /// \param[in] new_weights A vector of tensors with new shapes and data to use in the model. /// If a data pointer is null, the data of the original tensors will be copied to the new ones. /// /// \return Status. Status UpdateWeights(const std::vector<MSTensor> &new_weights);
/// \brief Runs inference on the model. /// /// \param[in] inputs A vector where model inputs are arranged in sequence. /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence. /// \param[in] before CallBack before predict. /// \param[in] after CallBack after predict. /// /// \return Status. Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Trains the model by one step. /// /// \param[in] before CallBack before the step. /// \param[in] after CallBack after the step. /// /// \return Status. Status RunStep(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Runs inference with the data preprocess contained in the model. /// /// \param[in] inputs A vector where model inputs are arranged in sequence. /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence. /// \param[in] before CallBack before predict. /// \param[in] after CallBack after predict. /// /// \return Status. Status PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Applies data preprocess if it exists in the model. /// /// \param[in] inputs A vector where model inputs are arranged in sequence. /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence. /// /// \return Status. Status Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs);
/// \brief Checks if data preprocess exists in the model. /// \return true if data preprocess exists. bool HasPreprocess();
/// \brief Loads a config file. /// /// \param[in] config_path config file path. /// /// \return Status. inline Status LoadConfig(const std::string &config_path);
/// \brief Updates config. /// /// \param[in] section define the config section. /// \param[in] config define the config item to be updated. /// /// \return Status. inline Status UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config);
/// \brief Obtains all input tensors of the model. /// /// \return The vector that includes all input tensors. std::vector<MSTensor> GetInputs();
/// \brief Obtains the input tensor of the model by name. /// /// \return The input tensor with the given name, if the name is not found, an invalid tensor is returned. inline MSTensor GetInputByTensorName(const std::string &tensor_name);
/// \brief Obtains all gradient tensors of the model. /// /// \return The vector that includes all gradient tensors. std::vector<MSTensor> GetGradients() const;
/// \brief Updates the gradient tensors of the model. /// /// \param[in] gradients A vector of new gradients. /// \return Status of operation. Status ApplyGradients(const std::vector<MSTensor> &gradients);
/// \brief Obtains all weight tensors of the model. /// /// \return The vector that includes all weight tensors. std::vector<MSTensor> GetFeatureMaps() const;
/// \brief Updates the weight tensors of the model. /// /// \param[in] new_weights A vector of new weight tensors. /// \return Status of operation. Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
/// \brief Obtains optimizer param tensors of the model. /// /// \return The vector that includes all param tensors. std::vector<MSTensor> GetOptimizerParams() const;
/// \brief Updates the optimizer parameters. /// /// \param[in] params A vector of new optimizer parameters. /// \return Status of operation. Status SetOptimizerParams(const std::vector<MSTensor> &params);
/// \brief Sets up training with virtual batches. /// /// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable /// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration /// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration /// \return Status of operation. Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f);
/// \brief Sets the learning rate for training. /// /// \param[in] learning_rate The learning rate to set. /// \return Status of operation. Status SetLearningRate(float learning_rate);
/// \brief Gets the learning rate of the optimizer. /// /// \return learning rate. 0.0 if no optimizer was found. float GetLearningRate();
Status InitMetrics(std::vector<Metrics *> metrics); std::vector<Metrics *> GetMetrics();
/// \brief Obtains all output tensors of the model. /// /// \return The vector that includes all output tensors. std::vector<MSTensor> GetOutputs();
/// \brief Obtains names of all output tensors of the model. /// /// \return A vector that includes names of all output tensors. inline std::vector<std::string> GetOutputTensorNames();
/// \brief Obtains the output tensor of the model by name. /// /// \return The output tensor with the given name, if the name is not found, an invalid tensor is returned. inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
/// \brief Get output MSTensors of model by node name. /// /// \param[in] node_name Define node name. /// /// \note Deprecated, replace with GetOutputByTensorName /// /// \return The vector of output MSTensor. inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &node_name);
/// \brief Bind GLTexture2D object to cl Memory. /// /// \param[in] inputGLTexture The input GLTexture id for Model. /// \param[in] outputGLTexture The output GLTexture id for Model. /// /// \return Status of operation. Status BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture, std::map<std::string, unsigned int> *outputGLTexture);
/// \brief Checks whether the given model type is supported on the device type. /// /// \param[in] device_type Device type, options are kGPU, kAscend etc. /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM. /// /// \return Is supported or not. static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
Status SetTrainMode(bool train); bool GetTrainMode() const; Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs); Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite. /// /// \param[in] model_data Define the buffer read from a model file. /// \param[in] data_size Define bytes number of model buffer. /// \param[in] model_type Define the type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only /// ModelType::kMindIR is valid for Lite. /// \param[in] model_context Define the context used to store options during execution. /// /// \return Status. Status Build(const void *model_data, size_t data_size, ModelType model_type, const std::shared_ptr<Context> &model_context = nullptr);
/// \brief Load and build a model from a model file so that it can run on a device. Only valid for Lite. /// /// \param[in] model_path Define the model path. /// \param[in] model_type Define the type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only /// ModelType::kMindIR is valid for Lite. /// \param[in] model_context Define the context used to store options during execution. /// /// \return Status. Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context = nullptr);
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite. /// /// \param[in] model_data Define the buffer read from a model file. /// \param[in] data_size Define bytes number of model buffer. /// \param[in] model_type Define the type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only /// ModelType::kMindIR is valid for Lite. /// \param[in] model_context Define the context used to store options during execution. /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16. /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM. /// \param[in] cropto_lib_path Define the openssl library path. /// /// \return Status. Status Build(const void *model_data, size_t data_size, ModelType model_type, const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path);
/// \brief Load and build a model from a model file so that it can run on a device. Only valid for Lite. /// /// \param[in] model_path Define the model path. /// \param[in] model_type Define the type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only /// ModelType::kMindIR is valid for Lite. /// \param[in] model_context Define the context used to store options during execution. /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16. /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM. /// \param[in] cropto_lib_path Define the openssl library path. /// /// \return Status. Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path);
private: friend class Serialization; // api without std::string MSTensor GetInputByTensorName(const std::vector<char> &tensor_name); std::vector<std::vector<char>> GetOutputTensorNamesChar(); MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name); std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name); Status LoadConfig(const std::vector<char> &config_path); Status UpdateConfig(const std::vector<char> &section, const std::pair<std::vector<char>, std::vector<char>> &config); Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context); Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode, const std::vector<char> &cropto_lib_path); std::shared_ptr<ModelImpl> impl_; };
MSTensor Model::GetInputByTensorName(const std::string &tensor_name) { return GetInputByTensorName(StringToChar(tensor_name)); } std::vector<std::string> Model::GetOutputTensorNames() { return VectorCharToString(GetOutputTensorNamesChar()); } MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) { return GetOutputByTensorName(StringToChar(tensor_name)); } std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) { return GetOutputsByNodeName(StringToChar(node_name)); } Status Model::LoadConfig(const std::string &config_path) { return LoadConfig(StringToChar(config_path)); } Status Model::UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) { std::pair<std::vector<char>, std::vector<char>> config_pair = {StringToChar(config.first), StringToChar(config.second)}; return UpdateConfig(StringToChar(section), config_pair); } inline Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path) { return Build(StringToChar(model_path), model_type, model_context, dec_key, dec_mode, StringToChar(cropto_lib_path)); } inline Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context) { return Build(StringToChar(model_path), model_type, model_context); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_MODEL_H
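// A typical inference flow against the Model API above: Build from a MindIR
// file, then Predict using the model's own input tensors. A minimal sketch;
// the file name is an assumption and device configuration on the Context is
// omitted (see include/api/context.h):
//
//   auto context = std::make_shared<mindspore::Context>();
//   mindspore::Model model;
//   mindspore::Status status = model.Build("model.mindir", mindspore::kMindIR, context);
//   if (status != mindspore::kSuccess) { /* inspect status.ToString() */ }
//   std::vector<mindspore::MSTensor> inputs = model.GetInputs();  // fill via MutableData()
//   std::vector<mindspore::MSTensor> outputs;
//   status = model.Predict(inputs, &outputs);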
================================================ FILE: tests/ut/stub/include/api/model_parallel_runner.h ================================================
/** * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H #define MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H #include <vector> #include <string> #include <memory> #include <utility> #include "include/api/status.h" #include "include/api/context.h" namespace mindspore { struct RunnerConfig { std::shared_ptr<Context> context = nullptr; int workers_num = 0; };
/// \brief The ModelParallelRunner class is used to define a MindSpore ModelParallelRunner, facilitating Model /// management. class MS_API ModelParallelRunner { public: ModelParallelRunner() = default; ~ModelParallelRunner() = default;
/// \brief Build a model parallel runner from a model path so that it can run on a device. Only valid for Lite. /// /// \param[in] model_path Define the model path. /// \param[in] runner_config Define the config used to store options during model pool init. /// /// \return Status. Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr);
/// \brief Obtains all input tensors information of the model. /// /// \return The vector that includes all input tensors. std::vector<MSTensor> GetInputs();
/// \brief Obtains all output tensors information of the model. /// /// \return The vector that includes all output tensors. std::vector<MSTensor> GetOutputs();
/// \brief Runs inference through the ModelParallelRunner. /// /// \param[in] inputs A vector where model inputs are arranged in sequence. /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence. /// \param[in] before CallBack before predict. /// \param[in] after CallBack after predict. /// /// \return Status. Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr); }; } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
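// The parallel runner fronts a pool of model workers with a single Predict()
// entry point, so concurrent requests can be served without the caller
// managing multiple Model objects. A minimal sketch; the model path and
// worker count are assumptions:
//
//   auto runner_config = std::make_shared<mindspore::RunnerConfig>();
//   runner_config->workers_num = 2;                  // size of the worker pool
//   mindspore::ModelParallelRunner runner;
//   mindspore::Status status = runner.Init("model.mindir", runner_config);
//   auto inputs = runner.GetInputs();                // fill input data here
//   std::vector<mindspore::MSTensor> outputs;
//   status = runner.Predict(inputs, &outputs);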
================================================ FILE: tests/ut/stub/include/api/ops/ops.h ================================================
/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_INCLUDE_API_OPS_OPS_H #define MINDSPORE_INCLUDE_API_OPS_OPS_H #include <string> #include <vector> #include <map> #include <memory> #include "include/api/status.h" #include "include/api/types.h" #include "include/api/cell.h" namespace mindspore { struct MS_API Conv2D : public OpCell<Conv2D> { Conv2D() : OpCell("Conv2D") {} ~Conv2D() override = default; std::vector<Output> Construct(const std::vector<Input> &inputs) override; Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid", const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1}, const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1); Output operator()(const Input &, const Input &) const; int out_channel; std::vector<int> kernel_size; int mode = 1; std::string pad_mode = "valid"; std::vector<int> pad = {0, 0, 0, 0}; std::vector<int> stride = {1, 1, 1, 1}; std::vector<int> dilation = {1, 1, 1, 1}; int group = 1; }; } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_OPS_OPS_H
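// Op cells such as Conv2D are applied functionally: constructing the cell
// captures the attributes, and operator() wires it to graph inputs. A sketch
// only; the stub above does not implement Construct, and the Input handles
// (x, weight) are assumptions:
//
//   mindspore::Conv2D conv(/*out_channel=*/64, /*kernel_size=*/{3, 3},
//                          /*mode=*/1, /*pad_mode=*/"same");
//   mindspore::Output feature = conv(x, weight);     // x, weight: mindspore::Input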
================================================ FILE: tests/ut/stub/include/api/serialization.h ================================================
/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_INCLUDE_API_SERIALIZATION_H #define MINDSPORE_INCLUDE_API_SERIALIZATION_H #include <string> #include <vector> #include <map> #include <memory> #include "include/api/status.h" #include "include/api/types.h" #include "include/api/model.h" #include "include/api/graph.h" #include "include/api/dual_abi_helper.h" namespace mindspore {
/// \brief The Serialization class is used to summarize methods for reading and writing model files. class MS_API Serialization { public:
/// \brief Loads a model file from memory buffer. /// /// \param[in] model_data A buffer filled by model file. /// \param[in] data_size The size of the buffer. /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM. /// \param[out] graph The output parameter, an object saves graph data. /// \param[in] dec_key The decryption key, key length is 16, 24, or 32. /// \param[in] dec_mode The decryption mode, options are AES-GCM, AES-CBC. /// /// \return Status. inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
/// \brief Loads a model file from path, is not supported on MindSpore Lite. /// /// \param[in] file The path of model file. /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM. /// \param[out] graph The output parameter, an object saves graph data. /// \param[in] dec_key The decryption key, key length is 16, 24, or 32. /// \param[in] dec_mode The decryption mode, options are AES-GCM, AES-CBC. /// /// \return Status. inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
/// \brief Load multiple models from multiple files, MindSpore Lite does not provide this feature. /// /// \param[in] files The path of model files. /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM. /// \param[out] graphs The output parameter, objects that save graph data. /// \param[in] dec_key The decryption key, key length is 16, 24, or 32. /// \param[in] dec_mode The decryption mode, options are AES-GCM, AES-CBC. /// /// \return Status. inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs, const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model); static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); inline static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file, QuantizationType quantization_type = kNoQuant, bool export_inference_only = true, std::vector<std::string> output_tensor_name = {});
private: static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key, const std::vector<char> &dec_mode); static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph); static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key, const std::vector<char> &dec_mode); static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode); static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file, QuantizationType quantization_type, bool export_inference_only, const std::vector<std::vector<char>> &output_tensor_name); };
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key, const std::string &dec_mode) { return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode)); } Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key, const std::string &dec_mode) { return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode)); } Status Serialization::Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs, const Key &dec_key, const std::string &dec_mode) { return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode)); } Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file, QuantizationType quantization_type, bool export_inference_only, std::vector<std::string> output_tensor_name) { return ExportModel(model, model_type, StringToChar(model_file), quantization_type, export_inference_only, VectorStringToChar(output_tensor_name)); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
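// Serialization::Load pairs with Model::Build(GraphCell(...)): Load parses a
// model file into a Graph, which is then wrapped in a GraphCell and built.
// A minimal sketch; the file name is an assumption:
//
//   mindspore::Graph graph;
//   mindspore::Status status =
//       mindspore::Serialization::Load("model.mindir", mindspore::kMindIR, &graph);
//   if (status == mindspore::kSuccess) {
//     mindspore::Model model;
//     auto context = std::make_shared<mindspore::Context>();
//     status = model.Build(mindspore::GraphCell(graph), context);
//   }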
================================================ FILE: tests/ut/stub/include/api/status.h ================================================
/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ #ifndef MINDSPORE_INCLUDE_API_STATUS_H #define MINDSPORE_INCLUDE_API_STATUS_H #include #include #include #include #include #include "include/api/dual_abi_helper.h" #include "include/api/types.h" namespace mindspore { enum CompCode : uint32_t { kCore = 0x00000000u, kMD = 0x10000000u, kME = 0x20000000u, kMC = 0x30000000u, kLite = 0xF0000000u, }; enum StatusCode : uint32_t { kSuccess = 0, // Core kCoreFailed = kCore | 0x1, // MD kMDOutOfMemory = kMD | 1, kMDShapeMisMatch = kMD | 2, kMDInterrupted = kMD | 3, kMDNoSpace = kMD | 4, kMDPyFuncException = kMD | 5, kMDDuplicateKey = kMD | 6, kMDPythonInterpreterFailure = kMD | 7, kMDTDTPushFailure = kMD | 8, kMDFileNotExist = kMD | 9, kMDProfilingError = kMD | 10, kMDBoundingBoxOutOfBounds = kMD | 11, kMDBoundingBoxInvalidShape = kMD | 12, kMDSyntaxError = kMD | 13, kMDTimeOut = kMD | 14, kMDBuddySpaceFull = kMD | 15, kMDNetWorkError = kMD | 16, kMDNotImplementedYet = kMD | 17, // Make this error code the last one. Add new error code above it. kMDUnexpectedError = kMD | 127, // ME kMEFailed = kME | 0x1, kMEInvalidInput = kME | 0x2, // MC kMCFailed = kMC | 0x1, kMCDeviceError = kMC | 0x2, kMCInvalidInput = kMC | 0x3, kMCInvalidArgs = kMC | 0x4, // Lite // Common error code, range: [-1, -100) kLiteError = kLite | (0x0FFFFFFF & -1), /**< Common error code. */ kLiteNullptr = kLite | (0x0FFFFFFF & -2), /**< NULL pointer returned.*/ kLiteParamInvalid = kLite | (0x0FFFFFFF & -3), /**< Invalid parameter.*/ kLiteNoChange = kLite | (0x0FFFFFFF & -4), /**< No change. */ kLiteSuccessExit = kLite | (0x0FFFFFFF & -5), /**< No error but exit. */ kLiteMemoryFailed = kLite | (0x0FFFFFFF & -6), /**< Fail to create memory. */ kLiteNotSupport = kLite | (0x0FFFFFFF & -7), /**< Fail to support. */ kLiteThreadPoolError = kLite | (0x0FFFFFFF & -8), /**< Error occur in thread pool. */ kLiteUninitializedObj = kLite | (0x0FFFFFFF & -9), /**< Object is not initialized. */ kLiteFileError = kLite | (0x0FFFFFFF & -10), /**< Invalid file. */ // Executor error code, range: [-100,-200) kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */ kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */ kLiteReentrantError = kLite | (0x0FFFFFFF & -102), /**< Exist executor running. */ // Graph error code, range: [-200,-300) kLiteGraphFileError = kLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */ // Node error code, range: [-300,-400) kLiteNotFindOp = kLite | (0x0FFFFFFF & -300), /**< Failed to find operator. */ kLiteInvalidOpName = kLite | (0x0FFFFFFF & -301), /**< Invalid operator name. */ kLiteInvalidOpAttr = kLite | (0x0FFFFFFF & -302), /**< Invalid operator attr. */ kLiteOpExecuteFailure = kLite | (0x0FFFFFFF & -303), /**< Failed to execution operator. */ // Tensor error code, range: [-400,-500) kLiteFormatError = kLite | (0x0FFFFFFF & -400), /**< Failed to checking tensor format. */ // InferShape error code, range: [-500,-600) kLiteInferError = kLite | (0x0FFFFFFF & -500), /**< Failed to infer shape. */ kLiteInferInvalid = kLite | (0x0FFFFFFF & -501), /**< Invalid infer shape before runtime. */ // User input param error code, range: [-600, 700) kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. 
*/ }; class MS_API Status { public: Status(); inline Status(enum StatusCode status_code, const std::string &status_msg = ""); // NOLINT(runtime/explicit) inline Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = ""); ~Status() = default; enum StatusCode StatusCode() const; inline std::string ToString() const; int GetLineOfCode() const; inline std::string GetErrDescription() const; inline std::string SetErrDescription(const std::string &err_description); MS_API friend std::ostream &operator<<(std::ostream &os, const Status &s); bool operator==(const Status &other) const; bool operator==(enum StatusCode other_code) const; bool operator!=(const Status &other) const; bool operator!=(enum StatusCode other_code) const; explicit operator bool() const; explicit operator int() const; static Status OK(); bool IsOk() const; bool IsError() const; static inline std::string CodeAsString(enum StatusCode c); private: // api without std::string Status(enum StatusCode status_code, const std::vector &status_msg); Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector &extra); std::vector ToCString() const; std::vector GetErrDescriptionChar() const; std::vector SetErrDescription(const std::vector &err_description); static std::vector CodeAsCString(enum StatusCode c); struct Data; std::shared_ptr data_; }; Status::Status(enum StatusCode status_code, const std::string &status_msg) : Status(status_code, StringToChar(status_msg)) {} Status::Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra) : Status(code, line_of_code, file_name, StringToChar(extra)) {} std::string Status::ToString() const { return CharToString(ToCString()); } std::string Status::GetErrDescription() const { return CharToString(GetErrDescriptionChar()); } std::string Status::SetErrDescription(const std::string &err_description) { return CharToString(SetErrDescription(StringToChar(err_description))); } std::string Status::CodeAsString(enum StatusCode c) { return CharToString(CodeAsCString(c)); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_STATUS_H ================================================ FILE: tests/ut/stub/include/api/types.h ================================================ /** * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_INCLUDE_API_TYPES_H #define MINDSPORE_INCLUDE_API_TYPES_H #include #include #include #include #include #include "include/api/data_type.h" #include "include/api/dual_abi_helper.h" #include "include/api/format.h" #include "include/api/visible.h" namespace mindspore { enum ModelType : uint32_t { kMindIR = 0, kAIR = 1, kOM = 2, kONNX = 3, kMindIR_Lite = 4, // insert new data type here kUnknownType = 0xFFFFFFFF }; enum QuantizationType : uint32_t { kNoQuant = 0, kWeightQuant = 1, kFullQuant = 2, kUnknownQuantType = 0xFFFFFFFF }; enum OptimizationLevel : uint32_t { kO0 = 0, // Do not change kO2 = 2, // Cast network to float16, keep batchnorm and loss in float32, kO3 = 3, // Cast network to float16, including bacthnorm kAuto = 4, // Choose optimization based on device kOptimizationType = 0xFFFFFFFF }; struct QuantParam { int bit_num; double scale; int32_t zero_point; double min; double max; }; class Allocator; /// \brief The MSTensor class defines a tensor in MindSpore. class MS_API MSTensor { public: class Impl; /// \brief Creates a MSTensor object, whose data need to be copied before accessed by Model, must be used in pairs /// with DestroyTensorPtr. /// /// \param[in] name The name of the MSTensor. /// \param[in] type The data type of the MSTensor. /// \param[in] shape The shape of the MSTensor. /// \param[in] data The data pointer that points to allocated memory. /// \param[in] data_len The length of the memory, in bytes. /// /// \return A pointer of MSTensor. static inline MSTensor *CreateTensor(const std::string &name, DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept; /// \brief Creates a MSTensor object, whose data can be directly accessed by Model, must be used in pairs with /// DestroyTensorPtr. /// /// \param[in] name The name of the MSTensor. /// \param[in] type The data type of the MSTensor. /// \param[in] shape The shape of the MSTensor. /// \param[in] data The data pointer that points to allocated memory. /// \param[in] data_len The length of the memory, in bytes. /// \param[in] own_data Whether the data memory should be freed in MSTensor destruction. /// /// \return A pointer of MSTensor. static inline MSTensor *CreateRefTensor(const std::string &name, DataType type, const std::vector &shape, const void *data, size_t data_len, bool own_data = true) noexcept; /// \brief Creates a MSTensor object, whose device data can be directly accessed by Model, must be used in pairs with /// DestroyTensorPtr. /// /// \param[in] name The name of the MSTensor. /// \param[in] type The data type of the MSTensor. /// \param[in] shape The shape of the MSTensor. /// \param[in] data The data pointer that points to device memory. /// \param[in] data_len The length of the memory, in bytes. /// /// \return A pointer of MSTensor. static inline MSTensor CreateDeviceTensor(const std::string &name, DataType type, const std::vector &shape, void *data, size_t data_len) noexcept; /// \brief Creates a MSTensor object from local file, must be used in pairs with DestroyTensorPtr. /// /// \param[in] file Path of file to be read. /// \param[in] type The data type of the MSTensor. /// \param[in] shape The shape of the MSTensor. /// /// \return A pointer of MSTensor. 
static inline MSTensor *CreateTensorFromFile(const std::string &file, DataType type = DataType::kNumberTypeUInt8, const std::vector &shape = {}) noexcept; /// \brief Create a string type MSTensor object whose data can be accessed by Model only after being copied, must be /// used in pair with DestroyTensorPtr. /// /// \param[in] name The name of the MSTensor. /// \param[in] str A vector container containing several strings. /// /// \return A pointer of MSTensor. static inline MSTensor *StringsToTensor(const std::string &name, const std::vector &str); /// \brief Parse the string type MSTensor object into strings. /// /// \param[in] tensor A MSTensor object. /// /// \return A vector container containing several strings. static inline std::vector TensorToStrings(const MSTensor &tensor); /// \brief Destroy an object created by Clone, StringsToTensor, CreateRefTensor or CreateTensor. Do /// not use it to destroy MSTensor from other sources. /// /// \param[in] tensor A MSTensor object. static void DestroyTensorPtr(MSTensor *tensor) noexcept; MSTensor(); explicit MSTensor(const std::shared_ptr &impl); // if malloc data, user need to free after constructing MSTensor, else memory leak. inline MSTensor(const std::string &name, DataType type, const std::vector &shape, const void *data, size_t data_len); explicit MSTensor(std::nullptr_t); ~MSTensor(); /// \brief Obtains the name of the MSTensor. /// /// \return The name of the MSTensor. inline std::string Name() const; /// \brief Obtains the data type of the MSTensor. /// /// \return The data type of the MSTensor. enum DataType DataType() const; /// \brief Obtains the shape of the MSTensor. /// /// \return The shape of the MSTensor. const std::vector &Shape() const; /// \brief Obtains the number of elements of the MSTensor. /// /// \return The number of elements of the MSTensor. int64_t ElementNum() const; /// \brief Obtains a shared pointer to the copy of data of the MSTensor. The data can be read on host. /// /// \return A shared pointer to the copy of data of the MSTensor. std::shared_ptr Data() const; /// \brief Obtains the pointer to the data of the MSTensor. If the MSTensor is a device tensor, the data cannot be /// accessed directly on host. /// /// \return A pointer to the data of the MSTensor. void *MutableData(); /// \brief Obtains the length of the data of the MSTensor, in bytes. /// /// \return The length of the data of the MSTensor, in bytes. size_t DataSize() const; /// \brief Get whether the MSTensor data is const data /// /// \return Const flag of MSTensor bool IsConst() const; /// \brief Gets the boolean value that indicates whether the memory of MSTensor is on device. /// /// \return The boolean value that indicates whether the memory of MSTensor is on device. bool IsDevice() const; /// \brief Gets a deep copy of the MSTensor, must be used in pair with DestroyTensorPtr. /// /// \return A pointer points to a deep copy of the MSTensor. MSTensor *Clone() const; /// \brief Gets the boolean value that indicates whether the MSTensor is valid. /// /// \return The boolean value that indicates whether the MSTensor is valid. bool operator==(std::nullptr_t) const; /// \brief Gets the boolean value that indicates whether the MSTensor is valid. /// /// \return The boolean value that indicates whether the MSTensor is valid. bool operator!=(std::nullptr_t) const; /// \brief Get the boolean value that indicates whether the MSTensor equals tensor. /// /// \param[in] another MSTensor. 
/// /// \return The boolean value that indicates whether the MSTensor equals tensor. bool operator==(const MSTensor &tensor) const; /// \brief Get the boolean value that indicates whether the MSTensor not equals tensor. /// /// \param[in] another MSTensor. /// /// \return The boolean value that indicates whether the MSTensor not equals tensor. bool operator!=(const MSTensor &tensor) const; /// \brief Set the shape of for the MSTensor. Only valid for Lite. /// /// \param[in] shape Shape of the MSTensor, a vector of int64_t. void SetShape(const std::vector &shape); /// \brief Set the data type for the MSTensor. Only valid for Lite. /// /// \param[in] data_type The data type of the MSTensor. void SetDataType(enum DataType data_type); /// \brief Set the name for the MSTensor. Only valid for Lite. /// /// \param[in] name The name of the MSTensor. inline void SetTensorName(const std::string &name); /// \brief Set the Allocator for the MSTensor. Only valid for Lite. /// /// \param[in] allocator A pointer to Allocator. void SetAllocator(std::shared_ptr allocator); /// \brief Obtain the Allocator of the MSTensor. Only valid for Lite. /// /// \return A pointer to Allocator. std::shared_ptr allocator() const; /// \brief Set the format for the MSTensor. Only valid for Lite. /// /// \param[in] format The format of the MSTensor. void SetFormat(mindspore::Format format); /// \brief Obtain the format of the MSTensor. Only valid for Lite. /// /// \return The format of the MSTensor. mindspore::Format format() const; /// \brief Set the data for the MSTensor. Only valid for Lite. /// /// \note Deprecated, this interface will be removed in the next iteration /// /// \note A pointer to the data should be created by malloc interface /// /// \note The memory pointed to origin data pointer of MSTensor needs to be managed by the user /// /// \param[in] data A pointer to the data of the MSTensor. /// \param[in] own_data Whether the data memory should be freed in MSTensor destruction. void SetData(void *data, bool own_data = true); /// \brief Set the device data address for the MSTensor. Only valid for Lite. /// /// \note The memory pointed to origin data pointer of MSTensor needs to be managed by the user /// /// \param[in] data A pointer to the device data of the MSTensor. void SetDeviceData(void *data); /// \brief Get the device data address of the MSTensor set by SetDeviceData. Only valid for Lite. /// /// \return A pointer to the device data of the MSTensor. void *GetDeviceData(); /// \brief Get the quantization parameters of the MSTensor. Only valid for Lite. /// /// \return The quantization parameters of the MSTensor. std::vector QuantParams() const; /// \brief Set the quantization parameters for the MSTensor. Only valid for Lite. /// /// \param[in] quant_params The quantization parameters of the MSTensor. 
void SetQuantParams(std::vector quant_params); const std::shared_ptr impl() const { return impl_; } private: // api without std::string static MSTensor *CreateTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept; static MSTensor *CreateRefTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len, bool own_data) noexcept; static MSTensor CreateDeviceTensor(const std::vector &name, enum DataType type, const std::vector &shape, void *data, size_t data_len) noexcept; static MSTensor *CreateTensorFromFile(const std::vector &file, enum DataType type, const std::vector &shape) noexcept; static MSTensor *CharStringsToTensor(const std::vector &name, const std::vector> &str); static std::vector> TensorToStringChars(const MSTensor &tensor); MSTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len); std::vector CharName() const; void SetTensorName(const std::vector &name); friend class ModelImpl; std::shared_ptr impl_; }; class MS_API Buffer { public: Buffer(); Buffer(const void *data, size_t data_len); ~Buffer(); const void *Data() const; void *MutableData(); size_t DataSize() const; bool ResizeData(size_t data_len); bool SetData(const void *data, size_t data_len); Buffer Clone() const; private: class Impl; std::shared_ptr impl_; }; MSTensor *MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept { return CreateTensor(StringToChar(name), type, shape, data, data_len); } MSTensor *MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len, bool own_data) noexcept { return CreateRefTensor(StringToChar(name), type, shape, data, data_len, own_data); } MSTensor MSTensor::CreateDeviceTensor(const std::string &name, enum DataType type, const std::vector &shape, void *data, size_t data_len) noexcept { return CreateDeviceTensor(StringToChar(name), type, shape, data, data_len); } MSTensor *MSTensor::CreateTensorFromFile(const std::string &file, enum DataType type, const std::vector &shape) noexcept { return CreateTensorFromFile(StringToChar(file), type, shape); } MSTensor *MSTensor::StringsToTensor(const std::string &name, const std::vector &str) { return CharStringsToTensor(StringToChar(name), VectorStringToChar(str)); } std::vector MSTensor::TensorToStrings(const MSTensor &tensor) { return VectorCharToString(TensorToStringChars(tensor)); } MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) : MSTensor(StringToChar(name), type, shape, data, data_len) {} std::string MSTensor::Name() const { return CharToString(CharName()); } void MSTensor::SetTensorName(const std::string &name) { SetTensorName(StringToChar(name)); } using Key = struct Key { const size_t max_key_len = 32; size_t len = 0; unsigned char key[32] = {0}; Key() : len(0) {} explicit Key(const char *dec_key, size_t key_len); }; constexpr char kDecModeAesGcm[] = "AES-GCM"; /// \brief CallBackParam defined input arguments for callBack function. struct MSCallBackParam { std::string node_name; /**< node name argument */ std::string node_type; /**< node type argument */ double execute_time; /**< gpu execute time */ }; /// \brief KernelCallBack defined the function pointer for callBack. 
using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> & /* inputs */, const std::vector<MSTensor> & /* outputs */, const MSCallBackParam &opInfo)>; std::vector<char> CharVersion(); inline std::string Version() { return CharToString(CharVersion()); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_TYPES_H
================================================ FILE: tests/ut/stub/include/api/visible.h ================================================
/** * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_INCLUDE_API_VISIBLE_H #define MINDSPORE_INCLUDE_API_VISIBLE_H #ifndef MS_API #ifdef _WIN32 #define MS_API __declspec(dllexport) #else #define MS_API __attribute__((visibility("default"))) #endif // _WIN32 #endif #ifdef _MSC_VER #ifdef BUILDING_DATASET_DLL #define DATASET_API __declspec(dllexport) #else #define DATASET_API __declspec(dllimport) #endif #else #define DATASET_API __attribute__((visibility("default"))) #endif // _MSC_VER #endif // MINDSPORE_INCLUDE_API_VISIBLE_H
================================================ FILE: tests/ut/stub/include/mindapi/base/format.h ================================================
/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CORE_MINDAPI_BASE_FORMAT_H_ #define MINDSPORE_CORE_MINDAPI_BASE_FORMAT_H_ #include <cstdint> namespace mindspore { enum Format : int64_t { NCHW = 0, NHWC = 1, NHWC4 = 2, HWKC = 3, HWCK = 4, KCHW = 5, CKHW = 6, KHWC = 7, CHWK = 8, HW = 9, HW4 = 10, NC = 11, NC4 = 12, NC4HW4 = 13, NUM_OF_FORMAT = 14, NCDHW = 15, NWC = 16, NCW = 17, }; } // namespace mindspore #endif // MINDSPORE_CORE_MINDAPI_BASE_FORMAT_H_
================================================ FILE: tests/ut/stub/include/mindapi/base/type_id.h ================================================
/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ #ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_ #define MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_ namespace mindspore { /// \brief TypeId defines data type identifiers. enum TypeId : int { kTypeUnknown = 0, // // Meta types. // kMetaTypeBegin = kTypeUnknown, kMetaTypeType, // Type kMetaTypeAnything, kMetaTypeObject, kMetaTypeTypeType, // TypeType kMetaTypeProblem, kMetaTypeExternal, kMetaTypeNone, kMetaTypeNull, kMetaTypeEllipsis, kMetaTypeEnd, // // Object types // kObjectTypeBegin = kMetaTypeEnd, kObjectTypeNumber, kObjectTypeString, kObjectTypeList, kObjectTypeTuple, kObjectTypeSlice, kObjectTypeKeyword, kObjectTypeTensorType, kObjectTypeRowTensorType, kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, kObjectTypeClass, kObjectTypeDictionary, kObjectTypeFunction, kObjectTypeJTagged, kObjectTypeSymbolicKeyType, kObjectTypeEnvType, kObjectTypeRefKey, kObjectTypeRef, kObjectTypeEnd, // // Number Types // kNumberTypeBegin = kObjectTypeEnd, kNumberTypeBool, kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex, kNumberTypeComplex64, kNumberTypeComplex128, kNumberTypeInt4, kNumberTypeGLUInt, kNumberTypeEnd, // // Monad Types // kMonadTypeBegin = kNumberTypeEnd, kObjectTypeMonad, kObjectTypeUMonad, kObjectTypeIOMonad, kMonadTypeEnd, // // Sparse Types // // Sparse types is placed at the end of enum, // in order to keep fit with the type of existing model on the lite side. kSparseTypeBegin = kMonadTypeEnd, kObjectTypeCSRTensorType, kSparseTypeEnd }; } // namespace mindspore #endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_ ================================================ FILE: tests/ut/stub/include/mindapi/base/types.h ================================================ /** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPES_H_ #define MINDSPORE_CORE_MINDAPI_BASE_TYPES_H_ #include namespace mindspore { enum CoordinateTransformMode : int64_t { ASYMMETRIC = 0, ALIGN_CORNERS = 1, HALF_PIXEL = 2, CROP_AND_RESIZE = 3, }; enum class ResizeMethod : int64_t { UNKNOWN = -1, LINEAR = 0, NEAREST = 1, CUBIC = 2, }; enum class NearestMode : int64_t { NORMAL = 0, ROUND_HALF_DOWN = 1, ROUND_HALF_UP = 2, FLOOR = 3, CEIL = 4, }; enum RoundMode : int64_t { FLOOR = 0, CEIL = 1, }; enum ActivationType : int64_t { NO_ACTIVATION = 0, RELU = 1, SIGMOID = 2, RELU6 = 3, ELU = 4, LEAKY_RELU = 5, ABS = 6, RELU1 = 7, SOFTSIGN = 8, SOFTPLUS = 9, TANH = 10, SELU = 11, HSWISH = 12, HSIGMOID = 13, THRESHOLDRELU = 14, LINEAR = 15, HARD_TANH = 16, SIGN = 17, SWISH = 18, GELU = 19, GLU = 20, UNKNOWN = 21, }; enum ReduceMode : int64_t { Reduce_Mean = 0, Reduce_Max = 1, Reduce_Min = 2, Reduce_Prod = 3, Reduce_Sum = 4, Reduce_Sum_Square = 5, Reduce_ASum = 6, Reduce_All = 7, }; enum EltwiseMode : int64_t { PROD = 0, SUM = 1, MAXIMUM = 2, ELTWISEMODE_UNKNOW = 3, }; enum Reduction : int64_t { REDUCTION_SUM = 0, MEAN = 1, NONE = 2, }; enum PadMode : int64_t { PAD = 0, SAME = 1, VALID = 2, }; enum class LshProjectionType : int64_t { UNKNOWN = 0, SPARSE = 1, DENSE = 2, }; enum PaddingMode : int64_t { CONSTANT = 0, REFLECT = 1, SYMMETRIC = 2, MODE_RESERVED = 3, }; } // namespace mindspore #endif // MINDSPORE_CORE_MINDAPI_BASE_TYPES_H_ ================================================ FILE: tests/ut/stub/include/utils/log_adapter.cc ================================================ /** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "utils/log_adapter.h" #define google mindspore_serving_private #ifndef _MSC_VER #include #include #endif #include #include #include // namespace to support utils module definition namespace mindspore { // set default log level to WARNING for all sub modules int g_ms_submodule_log_levels[NUM_SUBMODUES] = {INFO}; static std::string GetProcName() { #if defined(__APPLE__) || defined(__FreeBSD__) const std::string appname = getprogname(); #elif defined(_GNU_SOURCE) const std::string appname = program_invocation_name; #else const std::string appname = "?"; #endif // sometimes, the app name is an absolute path, it is too long std::string app_name(appname); std::size_t pos = app_name.rfind("/"); if (pos == std::string::npos) { return app_name; } if (pos + 1 >= app_name.size()) { return app_name; } return app_name.substr(pos + 1); } static std::string GetLogLevel(MsLogLevel level) { #define _TO_STRING(x) #x static const char *const level_names[] = { _TO_STRING(DEBUG), _TO_STRING(INFO), _TO_STRING(WARNING), _TO_STRING(ERROR), }; #undef _TO_STRING if (level > ERROR) { level = ERROR; } return std::string(level_names[level]); } // convert MsLogLevel to corresponding glog level static int GetGlogLevel(MsLogLevel level) { switch (level) { case DEBUG: case INFO: return google::GLOG_INFO; case WARNING: return google::GLOG_WARNING; case ERROR: default: return google::GLOG_ERROR; } } // get threshold level static int GetThresholdLevel(const std::string &threshold) { if (threshold.empty()) { return google::GLOG_WARNING; } else if (threshold == std::to_string(DEBUG) || threshold == std::to_string(INFO)) { return google::GLOG_INFO; } else if (threshold == std::to_string(WARNING)) { return google::GLOG_WARNING; } else if (threshold == std::to_string(ERROR)) { return google::GLOG_ERROR; } else { return google::GLOG_WARNING; } } void LogWriter::OutputLog(const std::ostringstream &msg) const { auto submodule_name = GetSubModuleName(submodule_); google::LogMessage("", 0, GetGlogLevel(log_level_)).stream() #ifdef _MSC_VER << "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << "," << std::hex #else << "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << std::hex #endif << std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " " << "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl; } void LogWriter::operator<(const LogStream &stream) const noexcept { std::ostringstream msg; msg << stream.sstream_->rdbuf(); OutputLog(msg); } void LogWriter::operator^(const LogStream &stream) const { std::ostringstream msg; msg << stream.sstream_->rdbuf(); std::ostringstream oss; oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; oss << msg.str(); thread_local bool running = false; if (!running) { running = true; OutputLog(msg); if (trace_provider_ != nullptr) { trace_provider_(oss); } running = false; } if (exception_handler_ != nullptr) { exception_handler_(exception_type_, oss.str()); } throw std::runtime_error(oss.str()); } static std::string GetEnv(const std::string &envvar) { const char *value = ::getenv(envvar.c_str()); if (value == nullptr) { return std::string(); } return std::string(value); } enum class LogConfigToken : size_t { INVALID, // indicate invalid token LEFT_BRACE, // '{' RIGHT_BRACE, // '}' VARIABLE, // '[A-Za-z][A-Za-z0-9_]*' NUMBER, // [0-9]+ COMMA, // ',' COLON, // ':' EOS, // End Of String, '\0' NUM_LOG_CFG_TOKENS }; static 
const char *g_tok_names[static_cast(LogConfigToken::NUM_LOG_CFG_TOKENS)] = { "invalid", // indicate invalid token "{", // '{' "}", // '}' "variable", // '[A-Za-z][A-Za-z0-9_]*' "number", // [0-9]+ ",", // ',' ":", // ':' "end-of-string", // End Of String, '\0' }; static inline bool IsAlpha(char ch) { return (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z'); } static inline bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; } class LogConfigLexer { public: explicit LogConfigLexer(const std::string &text) : buffer_(text), cur_idx_(0) {} ~LogConfigLexer() = default; // skip white space, and return the first char after white space char SkipWhiteSpace() { while (cur_idx_ < buffer_.size()) { char ch = buffer_[cur_idx_]; if (ch == ' ' || ch == '\t') { ++cur_idx_; continue; } return ch; } return '\0'; } LogConfigToken GetNext(std::string *const ptr) { #ifdef DEBUG std::string text; auto tok = GetNextInner(&text); MS_LOG(DEBUG) << "Got token " << tok << " with value [" << text << "]"; if (ptr != nullptr) { *ptr = text; } return tok; } LogConfigToken GetNextInner(std::string *ptr) { #endif char ch = SkipWhiteSpace(); // clang-format off static const std::map single_char_map = { {'{', LogConfigToken::LEFT_BRACE}, {'}', LogConfigToken::RIGHT_BRACE}, {',', LogConfigToken::COMMA}, {':', LogConfigToken::COLON}, {'\0', LogConfigToken::EOS}, }; // clang-format on auto iter = single_char_map.find(ch); if (iter != single_char_map.end()) { if (ptr != nullptr) { *ptr = std::string() + ch; } ++cur_idx_; return iter->second; } else if (IsAlpha(ch)) { std::ostringstream oss; do { oss << ch; ch = buffer_[++cur_idx_]; } while (cur_idx_ < buffer_.size() && (IsAlpha(ch) || IsDigit(ch) || ch == '_')); if (ptr != nullptr) { *ptr = std::string(oss.str()); } return LogConfigToken::VARIABLE; } else if (IsDigit(ch)) { std::ostringstream oss; do { oss << ch; ch = buffer_[++cur_idx_]; } while (cur_idx_ < buffer_.size() && IsDigit(ch)); if (ptr != nullptr) { *ptr = std::string(oss.str()); } return LogConfigToken::NUMBER; } return LogConfigToken::INVALID; } private: std::string buffer_; size_t cur_idx_; }; class LogConfigParser { public: explicit LogConfigParser(const std::string &cfg) : lexer(cfg) {} ~LogConfigParser() = default; bool Expect(LogConfigToken expected, LogConfigToken tok) const { if (expected != tok) { MS_LOG(WARNING) << "Parse submodule log configuration text error, expect `" << g_tok_names[static_cast(expected)] << "`, but got `" << g_tok_names[static_cast(tok)] << "`. The whole configuration will be ignored."; return false; } return true; } // The text of config MS_SUBMODULE_LOG_v is in the form {submodule1:log_level1,submodule2:log_level2,...}. // Valid values of log levels are: 0 - debug, 1 - info, 2 - warning, 3 - error // e.g. 
MS_SUBMODULE_LOG_v={PARSER:0, ANALYZER:2, PIPELINE:1} std::map Parse() { std::map log_levels; bool flag_error = false; std::string text; auto tok = lexer.GetNext(&text); // empty string if (tok == LogConfigToken::EOS) { return log_levels; } if (!Expect(LogConfigToken::LEFT_BRACE, tok)) { return log_levels; } do { std::string key, val; tok = lexer.GetNext(&key); if (!Expect(LogConfigToken::VARIABLE, tok)) { flag_error = true; break; } tok = lexer.GetNext(&text); if (!Expect(LogConfigToken::COLON, tok)) { flag_error = true; break; } tok = lexer.GetNext(&val); if (!Expect(LogConfigToken::NUMBER, tok)) { flag_error = true; break; } log_levels[key] = val; tok = lexer.GetNext(&text); } while (tok == LogConfigToken::COMMA); if (!flag_error && !Expect(LogConfigToken::RIGHT_BRACE, tok)) { flag_error = true; } if (flag_error) { log_levels.clear(); } return log_levels; } private: LogConfigLexer lexer; }; bool ParseLogLevel(const std::string &str_level, MsLogLevel *ptr_level) { constexpr char number_start = '0'; if (str_level.size() == 1) { int ch = str_level.c_str()[0]; ch = ch - number_start; // subtract ASCII code of '0', which is 48 if (ch >= DEBUG && ch <= ERROR) { if (ptr_level != nullptr) { *ptr_level = static_cast(ch); } return true; } } return false; } static MsLogLevel GetGlobalLogLevel() { return static_cast(FLAGS_v); } void InitSubModulesLogLevel() { // initialize submodule's log level using global auto global_log_level = GetGlobalLogLevel(); for (int i = 0; i < NUM_SUBMODUES; ++i) { g_ms_submodule_log_levels[i] = global_log_level; } // set submodule's log level auto submodule = GetEnv("MS_SUBMODULE_LOG_v"); MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; LogConfigParser parser(submodule); auto configs = parser.Parse(); for (const auto &cfg : configs) { int mod_idx = -1; for (int i = 0; i < NUM_SUBMODUES; ++i) { if (cfg.first == GetSubModuleName(static_cast(i))) { mod_idx = i; break; } } if (mod_idx < 0) { MS_LOG(WARNING) << "Undefined module name " << cfg.first << ", ignore it"; continue; } MsLogLevel submodule_log_level; if (!ParseLogLevel(cfg.second, &submodule_log_level)) { MS_LOG(WARNING) << "Illegal log level value " << cfg.second << " for " << cfg.first << ", ignore it."; continue; } g_ms_submodule_log_levels[mod_idx] = submodule_log_level; } } } // namespace mindspore extern "C" { #if defined(_WIN32) || defined(_WIN64) #ifdef _MSC_VER void common_log_init(void) { #else __attribute__((constructor)) void common_log_init(void) { #endif #else void common_log_init(void) { #endif // Do not use glog predefined log prefix FLAGS_log_prefix = false; // Write log to files real-time FLAGS_logbufsecs = 0; // Set default log level to WARNING if (mindspore::GetEnv("GLOG_v").empty()) { FLAGS_v = mindspore::WARNING; } // Set default log file mode to 0640 if (mindspore::GetEnv("GLOG_logfile_mode").empty()) { FLAGS_logfile_mode = 0640; } // Set default log file max size to 50 MB FLAGS_max_log_size = 50; std::string max_log_size = mindspore::GetEnv("GLOG_max_log_size"); if (!max_log_size.empty()) { FLAGS_max_log_size = std::stoi(max_log_size); } std::string logtostderr = mindspore::GetEnv("GLOG_logtostderr"); // Default print log to screen if (logtostderr.empty()) { FLAGS_logtostderr = true; } else if (logtostderr == "0") { if (mindspore::GetEnv("GLOG_log_dir").empty()) { MS_LOG(ERROR) << "`GLOG_log_dir` is empty, it must be set while 'logtostderr' equals to 0."; // Here can not throw exception and use python to catch, because the PYBIND11_MODULE is not yet been initialed. 
      exit(EXIT_FAILURE);
    } else {
      // Set log dir from GLOG_log_dir with RANK_ID or OMPI_COMM_WORLD_RANK; RANK_ID takes precedence.
      std::string rank_id = mindspore::GetEnv("RANK_ID");
      std::string gpu_rank_id = mindspore::GetEnv("OMPI_COMM_WORLD_RANK");
      std::string rank = "0";
      if (!rank_id.empty()) {
        rank = rank_id;
      } else if (!gpu_rank_id.empty()) {
        rank = gpu_rank_id;
      }
      FLAGS_log_dir = mindspore::GetEnv("GLOG_log_dir") + "/rank_" + rank + "/logs";
    }
  }

  // Default GLOG_stderrthreshold level to WARNING
  auto threshold = mindspore::GetEnv("GLOG_stderrthreshold");
  FLAGS_stderrthreshold = mindspore::GetThresholdLevel(threshold);

  mindspore::InitSubModulesLogLevel();
}
}
#undef google
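
// A minimal sketch of how a serving process is typically configured through
// the environment variables consumed above (the values are illustrative):
//
//   export GLOG_v=1                        # global level: 1 - info
//   export GLOG_logtostderr=0              # write to files instead of the screen
//   export GLOG_log_dir=/tmp/serving_log   # required when GLOG_logtostderr=0
//   export MS_SUBMODULE_LOG_v="{PARSER:0, ANALYZER:2}"  # per-module overrides
//
// With RANK_ID=3, the logs would then land under /tmp/serving_log/rank_3/logs.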
================================================
FILE: tests/ut/stub/include/utils/log_adapter.h
================================================
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_UTILS_LOG_ADAPTER_H_
#define MINDSPORE_CORE_UTILS_LOG_ADAPTER_H_

#include <stdarg.h>
#include <stdint.h>
#include <string>
#include <sstream>
#include <memory>
#include <thread>
#include <functional>
#include <map>
#include "utils/visible.h"
#include "utils/overload.h"
#include "./securec.h"
#define google mindspore_serving_private
#include "glog/logging.h"
#undef google

// NOTICE: when the relative path of 'log_adapter.h' changes, macro 'LOG_HDR_FILE_REL_PATH' must be changed
#define LOG_HDR_FILE_REL_PATH "mindspore/core/utils/log_adapter.h"

// Get start index of file relative path in __FILE__
static constexpr size_t GetRelPathPos() noexcept {
  return sizeof(__FILE__) > sizeof(LOG_HDR_FILE_REL_PATH) ? sizeof(__FILE__) - sizeof(LOG_HDR_FILE_REL_PATH) : 0;
}

namespace mindspore {
MS_CORE_API extern std::map<void **, std::thread *> acl_handle_map;

#define FILE_NAME                                                                             \
  (sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \
                                      : static_cast<const char *>(__FILE__))

enum ExceptionType {
  NoExceptionType = 0,
  UnknownError,
  ArgumentError,
  NotSupportError,
  NotExistsError,
  AlreadyExistsError,
  UnavailableError,
  DeviceProcessError,
  AbortedError,
  TimeOutError,
  ResourceUnavailable,
  NoPermissionError,
  IndexError,
  ValueError,
  TypeError,
  KeyError,
  AttributeError,
  NameError
};

struct LocationInfo {
  LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {}
  ~LocationInfo() = default;

  const char *file_;
  int line_;
  const char *func_;
};

class LogStream {
 public:
  LogStream() { sstream_ = std::make_shared<std::stringstream>(); }
  ~LogStream() = default;

  template <typename T>
  LogStream &operator<<(const T &val) noexcept {
    (*sstream_) << val;
    return *this;
  }

  LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept {
    (*sstream_) << func;
    return *this;
  }

  friend class LogWriter;

 private:
  std::shared_ptr<std::stringstream> sstream_;
};

template <class T, typename std::enable_if<std::is_enum<T>::value, int>::type = 0>
constexpr std::ostream &operator<<(std::ostream &stream, const T &value) {
  return stream << static_cast<typename std::underlying_type<T>::type>(value);
}

enum MsLogLevel : int { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION };

enum SubModuleId : int {
  SM_UNKNOWN = 0,        // unknown submodule
  SM_CORE,               // core
  SM_ANALYZER,           // static analyzer
  SM_COMMON,             // common
  SM_DEBUG,              // debug
  SM_OFFLINE_DEBUG,      // offline debug
  SM_DEVICE,             // device
  SM_GE_ADPT,            // ge adapter
  SM_IR,                 // IR
  SM_KERNEL,             // kernel
  SM_MD,                 // MindData
  SM_ME,                 // MindExpression
  SM_EXPRESS,            // EXPRESS_IR
  SM_OPTIMIZER,          // optimizer
  SM_PARALLEL,           // parallel
  SM_PARSER,             // parser
  SM_PIPELINE,           // ME pipeline
  SM_PRE_ACT,            // pre-activate
  SM_PYNATIVE,           // PyNative
  SM_SESSION,            // session
  SM_UTILS,              // utils
  SM_VM,                 // VM
  SM_PROFILER,           // profiler
  SM_PS,                 // Parameter Server
  SM_FL,                 // Federated Learning
  SM_LITE,               // LITE
  SM_ARMOUR,             // ARMOUR
  SM_HCCL_ADPT,          // Hccl Adapter
  SM_MINDQUANTUM,        // MindQuantum
  SM_RUNTIME_FRAMEWORK,  // Runtime framework
  SM_GE,                 // GraphEngine
  NUM_SUBMODUES          // number of submodules
};

#ifndef SUBMODULE_ID
#define SUBMODULE_ID mindspore::SubModuleId::SM_ME
#endif

MS_EXPORT const std::string GetSubModuleName(SubModuleId module_id);
const char *EnumStrForMsLogLevel(MsLogLevel level);
MS_EXPORT std::string GetTimeString();
MS_EXPORT extern int g_ms_submodule_log_levels[];

class LogWriter {
 public:
  using ExceptionHandler = std::function<void(ExceptionType, const std::string &)>;
  using TraceProvider = std::function<void(std::ostringstream &oss)>;

  LogWriter(const LocationInfo &location, MsLogLevel log_level, SubModuleId submodule,
            ExceptionType excp_type = NoExceptionType)
      : location_(location), log_level_(log_level), submodule_(submodule), exception_type_(excp_type) {}
  ~LogWriter() = default;

  MS_CORE_API void operator<(const LogStream &stream) const noexcept;
  MS_CORE_API void operator^(const LogStream &stream) const __attribute__((noreturn));

  static void set_exception_handler(ExceptionHandler exception_handler) { exception_handler_ = exception_handler; }
  static void set_trace_provider(TraceProvider trace_provider) { trace_provider_ = trace_provider; }
  static TraceProvider trace_provider() { return trace_provider_; }

 private:
  void OutputLog(const std::ostringstream &msg) const;

  LocationInfo location_;
  MsLogLevel log_level_;
  SubModuleId submodule_;
  ExceptionType exception_type_;

  inline static ExceptionHandler exception_handler_ = nullptr;
  inline static TraceProvider trace_provider_ = nullptr;
};
#define MSLOG_IF(level, condition, excp_type)                                                                       \
  static_cast<void>(0), !(condition)                                                                                \
                          ? void(0)                                                                                 \
                          : mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), level, \
                                                 SUBMODULE_ID, excp_type) < mindspore::LogStream()

#define MSLOG_THROW(excp_type)                                                                                         \
  mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), mindspore::EXCEPTION, SUBMODULE_ID, \
                       excp_type) ^                                                                                    \
    mindspore::LogStream()

#define IS_OUTPUT_ON(level) ((level) >= mindspore::g_ms_submodule_log_levels[SUBMODULE_ID])

#define MS_LOG(level) MS_LOG_##level

#define MS_LOG_DEBUG MSLOG_IF(mindspore::DEBUG, IS_OUTPUT_ON(mindspore::DEBUG), mindspore::NoExceptionType)
#define MS_LOG_INFO MSLOG_IF(mindspore::INFO, IS_OUTPUT_ON(mindspore::INFO), mindspore::NoExceptionType)
#define MS_LOG_WARNING MSLOG_IF(mindspore::WARNING, IS_OUTPUT_ON(mindspore::WARNING), mindspore::NoExceptionType)
#define MS_LOG_ERROR MSLOG_IF(mindspore::ERROR, IS_OUTPUT_ON(mindspore::ERROR), mindspore::NoExceptionType)

#define MS_LOG_EXCEPTION MSLOG_THROW(mindspore::NoExceptionType)
#define MS_EXCEPTION(type) MSLOG_THROW(type)
}  // namespace mindspore

#define MS_EXCEPTION_IF_NULL(ptr)                                    \
  do {                                                               \
    if ((ptr) == nullptr) {                                          \
      MS_LOG(EXCEPTION) << ": The pointer[" << #ptr << "] is null."; \
    }                                                                \
  } while (0)

#define MS_EXCEPTION_IF_ZERO(name, value)                   \
  do {                                                      \
    if (value == 0) {                                       \
      MS_LOG(EXCEPTION) << ": The " << name << " is zero."; \
    }                                                       \
  } while (0)

#define MS_ERROR_IF_NULL(ptr)                                    \
  do {                                                           \
    if ((ptr) == nullptr) {                                      \
      MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
      return false;                                              \
    }                                                            \
  } while (0)

#define MS_ERROR_IF_NULL_W_RET_VAL(ptr, val)                     \
  do {                                                           \
    if ((ptr) == nullptr) {                                      \
      MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
      return val;                                                \
    }                                                            \
  } while (0)

#define MS_ERROR_IF_NULL_WO_RET_VAL(ptr)                         \
  do {                                                           \
    if ((ptr) == nullptr) {                                      \
      MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
      return;                                                    \
    }                                                            \
  } while (0)

#ifdef DEBUG
#include <cassert>
#define MS_ASSERT(f) assert(f)
#else
#define MS_ASSERT(f) ((void)0)
#endif

#endif  // MINDSPORE_CORE_UTILS_LOG_ADAPTER_H_
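
// A minimal usage sketch for the macros defined above (illustrative only):
//
//   int *data = nullptr;
//   MS_LOG(INFO) << "Loading servable, batch size " << 32;  // streamed; filtered by IS_OUTPUT_ON
//   MS_EXCEPTION_IF_NULL(data);  // logs via MSLOG_THROW and raises through operator^, which is noreturn
//
// MS_LOG(level) expands to MS_LOG_##level, so the level token must be one of
// DEBUG/INFO/WARNING/ERROR/EXCEPTION spelled exactly as above.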
================================================
FILE: tests/ut/stub/include/utils/log_adapter_common.cc
================================================
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef _MSC_VER
#include <sys/time.h>
#endif
#include <vector>
#include <sstream>
#include <iomanip>
#include "utils/log_adapter.h"

namespace mindspore {
static const std::vector<std::string> sub_module_names = {
  "UNKNOWN",            // SM_UNKNOWN
  "CORE",               // SM_CORE
  "ANALYZER",           // SM_ANALYZER
  "COMMON",             // SM_COMMON
  "DEBUG",              // SM_DEBUG
  "OFFLINE_DEBUG",      // SM_OFFLINE_DEBUG
  "DEVICE",             // SM_DEVICE
  "GE_ADPT",            // SM_GE_ADPT
  "IR",                 // SM_IR
  "KERNEL",             // SM_KERNEL
  "MD",                 // SM_MD
  "ME",                 // SM_ME
  "EXPRESS",            // SM_EXPRESS
  "OPTIMIZER",          // SM_OPTIMIZER
  "PARALLEL",           // SM_PARALLEL
  "PARSER",             // SM_PARSER
  "PIPELINE",           // SM_PIPELINE
  "PRE_ACT",            // SM_PRE_ACT
  "PYNATIVE",           // SM_PYNATIVE
  "SESSION",            // SM_SESSION
  "UTILS",              // SM_UTILS
  "VM",                 // SM_VM
  "PROFILER",           // SM_PROFILER
  "PS",                 // SM_PS
  "FL",                 // SM_FL
  "LITE",               // SM_LITE
  "ARMOUR",             // SM_ARMOUR
  "HCCL_ADPT",          // SM_HCCL_ADPT
  "MINDQUANTUM",        // SM_MINDQUANTUM
  "RUNTIME_FRAMEWORK",  // SM_RUNTIME_FRAMEWORK
  "GE",                 // SM_GE
};

const std::string GetSubModuleName(SubModuleId module_id) {
  return sub_module_names[static_cast<size_t>(module_id % NUM_SUBMODUES)];
}

std::string GetTimeString() {
#if defined(_WIN32) || defined(_WIN64)
  time_t time_seconds = time(0);
  struct tm now_time;
  localtime_s(&now_time, &time_seconds);
  constexpr int base_year = 1900;
  std::stringstream ss;
  ss << now_time.tm_year + base_year << "-" << now_time.tm_mon + 1 << "-" << now_time.tm_mday << " "
     << now_time.tm_hour << ":" << now_time.tm_min << ":" << now_time.tm_sec;
  return ss.str();
#else
  constexpr auto BUFLEN = 80;
  char buf[BUFLEN] = {'\0'};
  struct timeval cur_time;
  (void)gettimeofday(&cur_time, nullptr);
  struct tm now;
  constexpr int width = 3;
  constexpr int64_t time_convert_unit = 1000;
  (void)localtime_r(&cur_time.tv_sec, &now);
  (void)strftime(buf, BUFLEN, "%Y-%m-%d-%H:%M:%S", &now);  // format date and time
  std::stringstream ss;
  ss << "." << std::setfill('0') << std::setw(width) << cur_time.tv_usec / time_convert_unit << "."
     << std::setfill('0') << std::setw(width) << cur_time.tv_usec % time_convert_unit;
  return std::string(buf) + ss.str();
#endif
}
}  // namespace mindspore
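
// On Linux the function above yields a string such as "2021-04-01-10:32:25.753.140"
// (value illustrative): seconds come from strftime(), then the milliseconds and
// the remaining microseconds are appended as two zero-padded 3-digit fields.
// The Windows branch omits the sub-second part and the zero padding.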
================================================
FILE: tests/ut/stub/include/utils/overload.h
================================================
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_UTILS_OVERLOAD_H_
#define MINDSPORE_CORE_UTILS_OVERLOAD_H_

#include <iostream>
#include <initializer_list>
#include <list>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace mindspore {
template <typename T>
std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
  out << "[const vector][";
  size_t last = v.size() - 1;
  for (size_t i = 0; i < v.size(); ++i) {
    out << v[i];
    if (i != last) out << ", ";
  }
  out << "]";
  return out;
}

template <typename T>
std::ostream &operator<<(std::ostream &os, const std::list<T> &vec) {
  bool begin = true;
  os << "[const list][";
  for (auto &item : vec) {
    if (!begin) {
      os << ", ";
    } else {
      begin = false;
    }
    os << item;
  }
  os << "]";
  return os;
}

template <typename T>
std::ostream &operator<<(std::ostream &os, const std::initializer_list<T> &vec) {
  bool begin = true;
  os << "[";
  for (auto &item : vec) {
    if (!begin) {
      os << ", ";
    } else {
      begin = false;
    }
    os << item;
  }
  os << "]";
  return os;
}

template <typename T>
bool operator==(const std::initializer_list<T> &lhs, const std::initializer_list<T> &rhs) {
  if (lhs.size() != rhs.size()) {
    return false;
  }
  auto lit = lhs.begin();
  auto rit = rhs.begin();
  while (lit != lhs.end()) {
    if (!(*lit == *rit)) {
      return false;
    }
    lit++;
    rit++;
  }
  return true;
}

template <typename T1, typename T2>
std::ostream &operator<<(std::ostream &os, const std::pair<T1, T2> &pair) {
  os << "[const pair]";
  return os;
}

template <typename T1, typename T2>
std::ostream &operator<<(std::ostream &os, const std::unordered_map<T1, T2> &map) {
  os << "[const unordered_map]";
  return os;
}

template <typename T1, typename T2>
std::ostream &operator<<(std::ostream &os, const std::map<T1, T2> &map) {
  os << "[const map]";
  return os;
}

template <typename T>
std::string ToString(const std::vector<T> &vec) {
  std::ostringstream buffer;
  buffer << vec;
  return buffer.str();
}

template <typename T1, typename T2>
std::string ToString(const std::unordered_map<T1, T2> &map) {
  std::ostringstream buffer;
  buffer << map;
  return buffer.str();
}

template <typename T1, typename T2>
std::string ToString(const std::map<T1, T2> &map) {
  std::ostringstream buffer;
  buffer << map;
  return buffer.str();
}
}  // namespace mindspore
#endif  // MINDSPORE_CORE_UTILS_OVERLOAD_H_
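
// A quick sketch of the resulting formatting (illustrative):
//
//   std::vector<int> v{1, 2, 3};
//   std::string s = mindspore::ToString(v);  // "[const vector][1, 2, 3]"
//
// Note that the pair/map overloads print only a type tag such as "[const map]"
// and deliberately drop the contents.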
================================================
FILE: tests/ut/stub/include/utils/utils.h
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_STUB_SERVING_UTILS_H
#define MINDSPORE_STUB_SERVING_UTILS_H

#include <unistd.h>
#include <atomic>
#include <cstdint>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "utils/log_adapter.h"

namespace mindspore {
class FuncGraph {
 public:
  explicit FuncGraph(const std::string &file_name) : file_name_(file_name) {}
  const std::string file_name_;
};
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

namespace common {
static inline const char *SafeCStr(const std::string &str) {
  const int CACHED_STR_NUM = 1 << 8;
  const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
  // The holder must be static: the returned pointer refers into it, and a
  // function-local vector would leave the caller with a dangling pointer.
  static std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
  static std::atomic<uint32_t> index{0};
  uint32_t cur_index = index++;
  cur_index = cur_index & CACHED_STR_MASK;
  STR_HOLDER[cur_index] = str;
  return STR_HOLDER[cur_index].c_str();
}

static inline bool DirOrFileExist(const std::string &file_path) {
  int ret = access(file_path.c_str(), 0);
  return ret != -1;
}
}  // namespace common

static inline size_t IntToSize(int i) { return static_cast<size_t>(i); }

typedef unsigned char Byte;
static inline std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path,
                                              const Byte *key, const size_t key_len, const std::string &dec_mode) {
  auto bytes = new Byte[10];
  return std::unique_ptr<Byte[]>(bytes);
}

static inline std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, const size_t data_size,
                                              const Byte *key, const size_t key_len, const std::string &dec_mode) {
  auto bytes = new Byte[10];
  return std::unique_ptr<Byte[]>(bytes);
}

static inline bool IsCipherFile(const std::string &file_path) { return false; }

static inline bool IsCipherFile(const Byte *model_data) { return false; }

static inline std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite,
                                                    const unsigned char *dec_key, const size_t key_len,
                                                    const std::string &dec_mode) {
  std::ifstream ifs(file_name);
  if (!ifs.good()) {
    MS_LOG(ERROR) << "File: " << file_name << " does not exist";
    return nullptr;
  }
  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "File: " << file_name << " open failed";
    return nullptr;
  }
  return std::make_shared<FuncGraph>(file_name);
}

static inline std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(
  const std::vector<std::string> file_names, bool is_lite = false, const unsigned char *dec_key = nullptr,
  const size_t key_len = 0, const std::string &dec_mode = std::string("AES-GCM")) {
  std::vector<std::shared_ptr<FuncGraph>> graphs;
  for (auto &file_name : file_names) {
    std::ifstream ifs(file_name);
    if (!ifs.good()) {
      MS_LOG(ERROR) << "File: " << file_name << " does not exist";
      return {};
    }
    if (!ifs.is_open()) {
      MS_LOG(ERROR) << "File: " << file_name << " open failed";
      return {};
    }
    graphs.push_back(std::make_shared<FuncGraph>(file_name));
  }
  return graphs;
}

static inline std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size,
                                                                  bool is_lite = false) {
  return std::make_shared<FuncGraph>("");
}

class MSTensor::Impl {
 public:
  Impl() = default;
  virtual ~Impl() = default;
  virtual const std::string &Name() const = 0;
  virtual enum DataType DataType() const = 0;
  virtual const std::vector<int64_t> &Shape() const = 0;
  virtual std::shared_ptr<const void> Data() const = 0;
  virtual void *MutableData() = 0;
  virtual size_t DataSize() const = 0;
  virtual bool IsDevice() const = 0;
  virtual std::shared_ptr<Impl> Clone() const = 0;
};
}  // namespace mindspore
#endif  // MINDSPORE_STUB_SERVING_UTILS_H
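
// A usage sketch for the SafeCStr cache above (illustrative): it hands out a
// C string whose storage outlives the temporary argument by rotating through
// 256 static slots, so at most 256 results remain valid at once:
//
//   const char *name = mindspore::common::SafeCStr(std::string("servable_") + "add");
//   // `name` stays valid until 256 further SafeCStr() calls recycle its slot.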
================================================
FILE: tests/ut/stub/include/utils/visible.h
================================================
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_UTILS_VISIBLE_H_
#define MINDSPORE_CORE_UTILS_VISIBLE_H_

#if (defined(_WIN32) || defined(__WIN32__) || defined(WIN32) || defined(__CYGWIN__))
#ifdef BUILDING_DLL
#define MS_CORE_API __declspec(dllexport)
#define MS_EXPORT __declspec(dllexport)
#else
#define MS_CORE_API __declspec(dllimport)
#define MS_EXPORT __declspec(dllimport)
#endif
#define MS_LOCAL
#else
#define MS_CORE_API __attribute__((visibility("default")))
#define MS_EXPORT __attribute__((visibility("default")))
#define MS_LOCAL __attribute__((visibility("hidden")))
#endif

#endif  // MINDSPORE_CORE_UTILS_VISIBLE_H_

================================================
FILE: tests/ut/stub/stub_inference.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include "worker/inference/inference.h"
#include "worker/inference/mindspore_model_wrap.h"

namespace mindspore::serving {
InferenceLoader::InferenceLoader() {}
InferenceLoader::~InferenceLoader() {}

std::string ModelContext::AsString() const {
  std::map<std::string, std::string> output_map;
  if (thread_num > -1) {
    output_map["thread num"] = AsStringHelper::AsString(thread_num);
  }
  if (!thread_affinity_core_list.empty()) {
    output_map["thread affinity core list"] = AsStringHelper::AsString(thread_affinity_core_list);
  }
  if (enable_parallel > -1) {
    output_map["enable parallel"] = AsStringHelper::AsString(enable_parallel);
  }
  if (!device_list.empty()) {
    output_map["device infos"] = AsStringHelper::AsString(device_list);
  }
  return AsStringHelper::AsString(output_map);
}

InferenceLoader &InferenceLoader::Instance() {
  static InferenceLoader inference;
  return inference;
}

std::shared_ptr<InferenceBase> InferenceLoader::CreateMindSporeInfer() {
  return std::make_shared<MindSporeModelWrap>();
}

Status InferenceLoader::LoadMindSporeModelWrap() { return SUCCESS; }

bool InferenceLoader::GetEnableLite() const { return enable_lite_; }

DeviceType InferenceLoader::GetSupportDeviceType(DeviceType device_type, ModelType model_type) {
  auto mindspore_infer = CreateMindSporeInfer();
  if (mindspore_infer == nullptr) {
    MSI_LOG_ERROR << "Create MindSpore infer failed";
    return kDeviceTypeNotSpecified;
  }
  std::vector<ModelType> check_model_types;
  if (model_type == kUnknownType) {
    check_model_types = {kMindIR, kMindIR_Lite, kOM};
  } else {
    check_model_types = {model_type};
  }
  for (auto &model_type_item : check_model_types) {
    if (device_type == kDeviceTypeNotSpecified) {
      auto device_list = {kDeviceTypeAscend, kDeviceTypeGpu, kDeviceTypeCpu};
      for (auto item : device_list) {
        if (mindspore_infer->CheckModelSupport(item, model_type_item)) {
          return item;
        }
      }
    } else {
      if (mindspore_infer->CheckModelSupport(device_type, model_type_item)) {
        return device_type;
      }
    }
  }
  return kDeviceTypeNotSpecified;
}

bool InferenceLoader::SupportReuseDevice() {
  auto mindspore_infer = CreateMindSporeInfer();
  if (mindspore_infer == nullptr) {
    MSI_LOG_ERROR << "Create MindSpore infer failed";
    return false;
  }
  return mindspore_infer->SupportReuseDevice();
}
}  // namespace mindspore::serving
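
// A sketch of the probing order implemented above (illustrative): when neither
// a device type nor a model type is pinned down, each candidate model format
// (kMindIR, kMindIR_Lite, kOM) is checked against Ascend, then GPU, then CPU,
// and the first supported device wins:
//
//   auto dev = InferenceLoader::Instance().GetSupportDeviceType(kDeviceTypeNotSpecified, kUnknownType);
//   // e.g. dev == kDeviceTypeCpu on a CPU-only build, or kDeviceTypeNotSpecified
//   // if no backend supports any of the candidate formats.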
================================================
FILE: tests/ut/stub/stub_postprocess.cc
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "worker/stage_function.h"
#include "mindspore_serving/ccsrc/common/tensor.h"

namespace mindspore::serving {
class StubCastFp32toInt32Postprocess : public CppStageFunctionBase {
 public:
  Status Call(const std::string &postprocess_name, const InstanceData &input, InstanceData *output) override {
    MSI_EXCEPTION_IF_NULL(output);
    auto x1 = input[0];
    if (x1->data_type() != kMSI_Float32) {
      return INFER_STATUS_LOG_ERROR(FAILED) << "Postprocess failed: Input data type invalid " << x1->data_type();
    }
    auto y1 = std::make_shared<Tensor>();
    y1->set_data_type(kMSI_Int32);
    y1->resize_data(x1->data_size());
    y1->set_shape(x1->shape());
    output->push_back(y1);

    auto x1_data = reinterpret_cast<const float *>(x1->data());
    auto y1_data = reinterpret_cast<int32_t *>(y1->mutable_data());
    for (size_t i = 0; i < y1->data_size() / 4; i++) {  // 4 == sizeof(int32_t), so this iterates per element
      y1_data[i] = static_cast<int32_t>(x1_data[i]);
    }
    return SUCCESS;
  }

  size_t GetInputsCount(const std::string &postprocess_name) const override { return 1; }
  size_t GetOutputsCount(const std::string &postprocess_name) const override { return 1; }
};

REGISTER_STAGE_FUNCTION(StubCastFp32toInt32Postprocess, "stub_postprocess_cast_fp32_to_int32_cpp")
}  // namespace mindspore::serving
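
// A wiring sketch (the servable_config hook described here is an assumption,
// it is not defined in this file): a servable stage refers to the function by
// its registered name, and at runtime the worker resolves
// "stub_postprocess_cast_fp32_to_int32_cpp" to the Call() above, e.g. turning
// the float32 tensor [1.7, 2.3] into the int32 tensor [1, 2] by truncation.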
*/ #include "worker/stage_function.h" #include "mindspore_serving/ccsrc/common/tensor.h" namespace mindspore::serving { class StubCastInt32toFp32Preprocess : public CppStageFunctionBase { public: Status Call(const std::string &postprocess_name, const InstanceData &input, InstanceData *output) override { MSI_EXCEPTION_IF_NULL(output); auto x1 = input[0]; auto x2 = input[1]; if (x1->data_type() != kMSI_Int32 || x2->data_type() != kMSI_Int32) { return INFER_STATUS_LOG_ERROR(FAILED) << "Call failed: Input data type invalid " << x1->data_type() << ", " << x2->data_type(); } auto y1 = std::make_shared(); y1->set_data_type(serving::kMSI_Float32); y1->resize_data(x1->data_size()); y1->set_shape(x1->shape()); output->push_back(y1); auto y2 = std::make_shared(); y2->set_data_type(serving::kMSI_Float32); y2->resize_data(x2->data_size()); y2->set_shape(x2->shape()); output->push_back(y2); auto x1_data = reinterpret_cast(x1->data()); auto y1_data = reinterpret_cast(y1->mutable_data()); for (size_t i = 0; i < y1->data_size() / 4; i++) { y1_data[i] = static_cast(x1_data[i]); } auto x2_data = reinterpret_cast(x2->data()); auto y2_data = reinterpret_cast(y2->mutable_data()); for (size_t i = 0; i < y2->data_size() / 4; i++) { y2_data[i] = static_cast(x2_data[i]); } return SUCCESS; } size_t GetInputsCount(const std::string &postprocess_name) const override { return 2; } size_t GetOutputsCount(const std::string &postprocess_name) const override { return 2; } }; REGISTER_STAGE_FUNCTION(StubCastInt32toFp32Preprocess, "stub_preprocess_cast_int32_to_fp32_cpp") } // namespace mindspore::serving ================================================ FILE: third_party/patch/c-ares/CVE-2021-3672.patch ================================================ diff -Npur c-ares-1.15.0/ares_expand_name.c c-ares-1.15.0-new/ares_expand_name.c --- c-ares-1.15.0/ares_expand_name.c 2017-07-03 17:04:19.000000000 +0800 +++ c-ares-1.15.0-new/ares_expand_name.c 2021-08-21 22:48:24.650973166 +0800 @@ -38,6 +38,26 @@ static int name_length(const unsigned char *encoded, const unsigned char *abuf, int alen); +/* Reserved characters for names that need to be escaped */ +static int is_reservedch(int ch) +{ + switch (ch) { + case '"': + case '.': + case ';': + case '\\': + case '(': + case ')': + case '@': + case '$': + return 1; + default: + break; + } + + return 0; +} + /* Expand an RFC1035-encoded domain name given by encoded. The * containing message is given by abuf and alen. The result given by * *s, which is set to a NUL-terminated allocated buffer. *enclen is @@ -113,18 +133,37 @@ int ares_expand_name(const unsigned char } else { - len = *p; + int name_len = *p; + len = name_len; p++; + while (len--) { - if (*p == '.' || *p == '\\') - *q++ = '\\'; - *q++ = *p; + /* Output as \DDD for consistency with RFC1035 5.1, except + * for the special case of a root name response */ + if (!isprint(*p) && !(name_len == 1 && *p == 0)) + { + + *q++ = '\\'; + *q++ = '0' + *p / 100; + *q++ = '0' + (*p % 100) / 10; + *q++ = '0' + (*p % 10); + } + else if (is_reservedch(*p)) + { + *q++ = '\\'; + *q++ = *p; + } + else + { + *q++ = *p; + } p++; } *q++ = '.'; } - } + } + if (!indir) *enclen = aresx_uztosl(p + 1U - encoded); @@ -171,15 +210,29 @@ static int name_length(const unsigned ch } else if (top == 0x00) { - offset = *encoded; + int name_len = *encoded; + offset = name_len; if (encoded + offset + 1 >= abuf + alen) return -1; encoded++; + while (offset--) { - n += (*encoded == '.' || *encoded == '\\') ? 
2 : 1; + if (!isprint(*encoded) && !(name_len == 1 && *encoded == 0)) + { + n += 4; + } + else if (is_reservedch(*encoded)) + { + n += 2; + } + else + { + n += 1; + } encoded++; } + n++; } else ================================================ FILE: third_party/patch/glog/glog.patch001 ================================================ diff -Npur glog/CMakeLists.txt glog-patch/CMakeLists.txt --- glog/CMakeLists.txt 2019-03-22 10:51:46.000000000 +0800 +++ glog-patch/CMakeLists.txt 2021-04-01 10:32:25.753140500 +0800 @@ -64,7 +64,6 @@ check_include_file (dlfcn.h HAVE_DLFCN_H check_include_file (execinfo.h HAVE_EXECINFO_H) check_include_file (glob.h HAVE_GLOB_H) check_include_file (inttypes.h HAVE_INTTYPES_H) -check_include_file (libunwind.h HAVE_LIBUNWIND_H) check_include_file (memory.h HAVE_MEMORY_H) check_include_file (pwd.h HAVE_PWD_H) check_include_file (stdint.h HAVE_STDINT_H) @@ -80,7 +79,6 @@ check_include_file (syscall.h HAVE_SYSCA check_include_file (syslog.h HAVE_SYSLOG_H) check_include_file (ucontext.h HAVE_UCONTEXT_H) check_include_file (unistd.h HAVE_UNISTD_H) -check_include_file (unwind.h HAVE_UNWIND_H) check_include_file (pwd.h HAVE_PWD_H) check_include_file_cxx ("ext/hash_map" HAVE_EXT_HASH_MAP) @@ -116,12 +114,8 @@ check_cxx_compiler_flag (-Wunnamed-type- # snprintf as an inline function check_symbol_exists (snprintf stdio.h HAVE_SNPRINTF) -check_library_exists (unwind get_static_proc_name "" HAVE_LIB_UNWIND) check_library_exists (dbghelp UnDecorateSymbolName "" HAVE_DBGHELP) -find_library (UNWIND_LIBRARY NAMES unwind DOC "unwind library") -mark_as_advanced (UNWIND_LIBRARY) - check_c_source_compiles (" #include static void foo(void) __attribute__ ((unused)); @@ -470,10 +464,7 @@ add_library (glog add_library(glog::glog ALIAS glog) set_target_properties (glog PROPERTIES POSITION_INDEPENDENT_CODE ON) - -if (UNWIND_LIBRARY) - target_link_libraries (glog PUBLIC ${UNWIND_LIBRARY}) -endif (UNWIND_LIBRARY) +set_target_properties (glog PROPERTIES OUTPUT_NAME mindspore_serving_glog) if (HAVE_DBGHELP) target_link_libraries (glog PUBLIC dbghelp) ================================================ FILE: third_party/patch/grpc/grpc.patch001 ================================================ diff -Npur grpc/..rej grpc-patch/..rej --- grpc/..rej 1970-01-01 08:00:00.000000000 +0800 +++ grpc-patch/..rej 2021-04-22 21:00:17.343178600 +0800 @@ -0,0 +1,22 @@ +--- CMakeLists.txt 2020-02-27 03:12:33.000000000 +0800 ++++ CMakeLists.txt 2021-04-07 21:27:12.317207600 +0800 +@@ -12992,7 +12992,7 @@ if(gRPC_BUILD_CODEGEN AND gRPC_BUILD_GRP + add_executable(grpc_cpp_plugin + src/compiler/cpp_plugin.cc + ) +- ++set_target_properties(grpc_cpp_plugin PROPERTIES INSTALL_RPATH $ORIGIN/../lib) + target_include_directories(grpc_cpp_plugin + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +--- cmake/cares.cmake 2020-02-27 03:12:33.000000000 +0800 ++++ cmake/cares.cmake 2021-04-10 14:22:35.895725700 +0800 +@@ -39,7 +39,7 @@ if(gRPC_CARES_PROVIDER STREQUAL "module" + set(gRPC_INSTALL FALSE) + endif() + elseif(gRPC_CARES_PROVIDER STREQUAL "package") +- find_package(c-ares 1.13.0 REQUIRED) ++ find_package(c-ares REQUIRED) # cmake 3.19+ cannot find cares 1.15.0 + if(TARGET c-ares::cares) + set(_gRPC_CARES_LIBRARIES c-ares::cares) + endif() diff -Npur grpc/.rej grpc-patch/.rej --- grpc/.rej 1970-01-01 08:00:00.000000000 +0800 +++ grpc-patch/.rej 2021-04-22 21:03:38.192349100 +0800 @@ -0,0 +1,22 @@ +--- grpc/CMakeLists.txt 2020-02-27 03:12:33.000000000 +0800 ++++ grpc-patch/CMakeLists.txt 2021-04-07 21:27:12.317207600 +0800 +@@ 
-12992,7 +12992,7 @@ if(gRPC_BUILD_CODEGEN AND gRPC_BUILD_GRP + add_executable(grpc_cpp_plugin + src/compiler/cpp_plugin.cc + ) +- ++set_target_properties(grpc_cpp_plugin PROPERTIES INSTALL_RPATH $ORIGIN/../lib) + target_include_directories(grpc_cpp_plugin + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +--- grpc/cmake/cares.cmake 2020-02-27 03:12:33.000000000 +0800 ++++ grpc-patch/cmake/cares.cmake 2021-04-10 14:22:35.895725700 +0800 +@@ -39,7 +39,7 @@ if(gRPC_CARES_PROVIDER STREQUAL "module" + set(gRPC_INSTALL FALSE) + endif() + elseif(gRPC_CARES_PROVIDER STREQUAL "package") +- find_package(c-ares 1.13.0 REQUIRED) ++ find_package(c-ares REQUIRED) # cmake 3.19+ cannot find cares 1.15.0 + if(TARGET c-ares::cares) + set(_gRPC_CARES_LIBRARIES c-ares::cares) + endif() diff -Npur grpc/CMakeLists.txt grpc-patch/CMakeLists.txt --- grpc/CMakeLists.txt 2020-02-27 03:12:33.000000000 +0800 +++ grpc-patch/CMakeLists.txt 2021-04-22 21:15:04.458188400 +0800 @@ -936,6 +936,8 @@ set_target_properties(address_sorting PR SOVERSION ${gRPC_CORE_SOVERSION} ) +set_target_properties(address_sorting PROPERTIES OUTPUT_NAME mindspore_serving_address_sorting) + if(WIN32 AND MSVC) set_target_properties(address_sorting PROPERTIES COMPILE_PDB_NAME "address_sorting" COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" @@ -1404,6 +1406,8 @@ set_target_properties(gpr PROPERTIES SOVERSION ${gRPC_CORE_SOVERSION} ) +set_target_properties(gpr PROPERTIES OUTPUT_NAME mindspore_serving_gpr) + if(WIN32 AND MSVC) set_target_properties(gpr PROPERTIES COMPILE_PDB_NAME "gpr" COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" @@ -1869,6 +1873,8 @@ set_target_properties(grpc PROPERTIES SOVERSION ${gRPC_CORE_SOVERSION} ) +set_target_properties(grpc PROPERTIES OUTPUT_NAME mindspore_serving_grpc) + if(WIN32 AND MSVC) set_target_properties(grpc PROPERTIES COMPILE_PDB_NAME "grpc" COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" @@ -3696,6 +3702,8 @@ set_target_properties(grpc++ PROPERTIES SOVERSION ${gRPC_CPP_SOVERSION} ) +set_target_properties(grpc++ PROPERTIES OUTPUT_NAME mindspore_serving_grpc++) + if(WIN32 AND MSVC) set_target_properties(grpc++ PROPERTIES COMPILE_PDB_NAME "grpc++" COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" @@ -4279,6 +4287,8 @@ set_target_properties(grpc++_reflection SOVERSION ${gRPC_CPP_SOVERSION} ) +set_target_properties(grpc++_reflection PROPERTIES OUTPUT_NAME mindspore_serving_grpc++_reflection) + if(WIN32 AND MSVC) set_target_properties(grpc++_reflection PROPERTIES COMPILE_PDB_NAME "grpc++_reflection" COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" @@ -5896,6 +5906,8 @@ set_target_properties(upb PROPERTIES SOVERSION ${gRPC_CORE_SOVERSION} ) +set_target_properties(upb PROPERTIES OUTPUT_NAME mindspore_serving_upb) + if(WIN32 AND MSVC) set_target_properties(upb PROPERTIES COMPILE_PDB_NAME "upb" COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" @@ -12992,7 +13004,7 @@ if(gRPC_BUILD_CODEGEN AND gRPC_BUILD_GRP add_executable(grpc_cpp_plugin src/compiler/cpp_plugin.cc ) - +set_target_properties(grpc_cpp_plugin PROPERTIES INSTALL_RPATH $ORIGIN/../lib) target_include_directories(grpc_cpp_plugin PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} @@ -13251,6 +13263,8 @@ add_executable(grpc_python_plugin src/compiler/python_plugin.cc ) +set_target_properties(grpc_python_plugin PROPERTIES INSTALL_RPATH $ORIGIN/../lib) + target_include_directories(grpc_python_plugin PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} diff -Npur grpc/cmake/cares.cmake grpc-patch/cmake/cares.cmake --- grpc/cmake/cares.cmake 2020-02-27 03:12:33.000000000 +0800 +++ 
grpc-patch/cmake/cares.cmake 2021-04-22 21:05:06.398251400 +0800 @@ -39,7 +39,7 @@ if(gRPC_CARES_PROVIDER STREQUAL "module" set(gRPC_INSTALL FALSE) endif() elseif(gRPC_CARES_PROVIDER STREQUAL "package") - find_package(c-ares 1.13.0 REQUIRED) + find_package(c-ares REQUIRED) # cmake 3.19+ cannot find cares 1.15.0 if(TARGET c-ares::cares) set(_gRPC_CARES_LIBRARIES c-ares::cares) endif() ================================================ FILE: third_party/patch/libevent/libevent.patch001 ================================================ diff -Npur libevent/CMakeLists.txt libevent-modify/CMakeLists.txt --- libevent/CMakeLists.txt 2020-07-05 20:02:46.000000000 +0800 +++ libevent-modify/CMakeLists.txt 2021-04-19 16:36:57.982307500 +0800 @@ -852,7 +852,7 @@ if (NOT EVENT__DISABLE_OPENSSL) list(APPEND SRC_OPENSSL bufferevent_openssl.c) list(APPEND HDR_PUBLIC include/event2/bufferevent_ssl.h) - list(APPEND LIB_APPS ${OPENSSL_LIBRARIES}) + list(APPEND LIB_APPS ${OPENSSL_LIBRARIES} -ldl) endif() if (NOT EVENT__DISABLE_THREAD_SUPPORT) diff -Npur libevent/cmake/AddEventLibrary.cmake libevent-modify/cmake/AddEventLibrary.cmake --- libevent/cmake/AddEventLibrary.cmake 2020-07-05 20:02:46.000000000 +0800 +++ libevent-modify/cmake/AddEventLibrary.cmake 2021-04-19 16:36:57.982307500 +0800 @@ -153,1 +153,0 @@ - INSTALL_NAME_DIR "${CMAKE_INSTALL_PREFIX}/lib" ================================================ FILE: third_party/patch/openssl/CVE-2021-3711.patch ================================================ diff --git a/crypto/sm2/sm2_crypt.c b/crypto/sm2/sm2_crypt.c index ef505f6441..1188abfc6b 100644 --- a/crypto/sm2/sm2_crypt.c +++ b/crypto/sm2/sm2_crypt.c @@ -61,29 +61,20 @@ static size_t ec_field_size(const EC_GROUP *group) return field_size; } -int sm2_plaintext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len, - size_t *pt_size) +int sm2_plaintext_size(const unsigned char *ct, size_t ct_size, size_t *pt_size) { - const size_t field_size = ec_field_size(EC_KEY_get0_group(key)); - const int md_size = EVP_MD_size(digest); - size_t overhead; + struct SM2_Ciphertext_st *sm2_ctext = NULL; - if (md_size < 0) { - SM2err(SM2_F_SM2_PLAINTEXT_SIZE, SM2_R_INVALID_DIGEST); - return 0; - } - if (field_size == 0) { - SM2err(SM2_F_SM2_PLAINTEXT_SIZE, SM2_R_INVALID_FIELD); - return 0; - } + sm2_ctext = d2i_SM2_Ciphertext(NULL, &ct, ct_size); - overhead = 10 + 2 * field_size + (size_t)md_size; - if (msg_len <= overhead) { + if (sm2_ctext == NULL) { SM2err(SM2_F_SM2_PLAINTEXT_SIZE, SM2_R_INVALID_ENCODING); return 0; } - *pt_size = msg_len - overhead; + *pt_size = sm2_ctext->C2->length; + SM2_Ciphertext_free(sm2_ctext); + return 1; } diff --git a/crypto/sm2/sm2_pmeth.c b/crypto/sm2/sm2_pmeth.c index b42a14c32f..27025fbf3a 100644 --- a/crypto/sm2/sm2_pmeth.c +++ b/crypto/sm2/sm2_pmeth.c @@ -151,7 +151,7 @@ static int pkey_sm2_decrypt(EVP_PKEY_CTX *ctx, const EVP_MD *md = (dctx->md == NULL) ? 
EVP_sm3() : dctx->md; if (out == NULL) { - if (!sm2_plaintext_size(ec, md, inlen, outlen)) + if (!sm2_plaintext_size(in, inlen, outlen)) return -1; else return 1; diff --git a/include/crypto/sm2.h b/include/crypto/sm2.h index 76ee80baff..50851a83ce 100644 --- a/include/crypto/sm2.h +++ b/include/crypto/sm2.h @@ -60,8 +60,7 @@ int sm2_verify(const unsigned char *dgst, int dgstlen, int sm2_ciphertext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len, size_t *ct_size); -int sm2_plaintext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len, - size_t *pt_size); +int sm2_plaintext_size(const unsigned char *ct, size_t ct_size, size_t *pt_size); int sm2_encrypt(const EC_KEY *key, const EVP_MD *digest, diff --git a/test/sm2_internal_test.c b/test/sm2_internal_test.c index 2bb73947ff..41827bb82f 100644 --- a/test/sm2_internal_test.c +++ b/test/sm2_internal_test.c @@ -185,7 +185,7 @@ static int test_sm2_crypt(const EC_GROUP *group, if (!TEST_mem_eq(ctext, ctext_len, expected, ctext_len)) goto done; - if (!TEST_true(sm2_plaintext_size(key, digest, ctext_len, &ptext_len)) + if (!TEST_true(sm2_plaintext_size(ctext, ctext_len, &ptext_len)) || !TEST_int_eq(ptext_len, msg_len)) goto done; ================================================ FILE: third_party/patch/openssl/CVE-2021-3712.patch ================================================ diff --git a/crypto/ec/ec_asn1.c b/crypto/ec/ec_asn1.c index 7b7c75ce84..e497a25909 100644 --- a/crypto/ec/ec_asn1.c +++ b/crypto/ec/ec_asn1.c @@ -761,7 +761,10 @@ EC_GROUP *EC_GROUP_new_from_ecparameters(const ECPARAMETERS *params) ret->seed_len = params->curve->seed->length; } - if (!params->order || !params->base || !params->base->data) { + if (params->order == NULL + || params->base == NULL + || params->base->data == NULL + || params->base->length == 0) { ECerr(EC_F_EC_GROUP_NEW_FROM_ECPARAMETERS, EC_R_ASN1_ERROR); goto err; } ================================================ FILE: third_party/patch/openssl/CVE-2021-4160.patch ================================================ diff --git a/crypto/bn/asm/mips.pl b/crypto/bn/asm/mips.pl index 95cb227dc5..91b7aac6e7 100644 --- a/crypto/bn/asm/mips.pl +++ b/crypto/bn/asm/mips.pl @@ -1986,6 +1986,8 @@ $code.=<<___; sltu $at,$c_2,$t_1 $ADDU $c_3,$t_2,$at $ST $c_2,$BNSZ($a0) + sltu $at,$c_3,$t_2 + $ADDU $c_1,$at mflo ($t_1,$a_2,$a_0) mfhi ($t_2,$a_2,$a_0) ___ @@ -2196,6 +2198,8 @@ $code.=<<___; sltu $at,$c_2,$t_1 $ADDU $c_3,$t_2,$at $ST $c_2,$BNSZ($a0) + sltu $at,$c_3,$t_2 + $ADDU $c_1,$at mflo ($t_1,$a_2,$a_0) mfhi ($t_2,$a_2,$a_0) ___ diff --git a/test/bntest.c b/test/bntest.c index 87e5c4065b..fa9fc07cef 100644 --- a/test/bntest.c +++ b/test/bntest.c @@ -630,6 +630,51 @@ static int test_modexp_mont5(void) if (!TEST_BN_eq(c, d)) goto err; + /* + * Regression test for overflow bug in bn_sqr_comba4/8 for + * mips-linux-gnu and mipsel-linux-gnu 32bit targets. 
+ */ + { + static const char *ehex[] = { + "95564994a96c45954227b845a1e99cb939d5a1da99ee91acc962396ae999a9ee", + "38603790448f2f7694c242a875f0cad0aae658eba085f312d2febbbd128dd2b5", + "8f7d1149f03724215d704344d0d62c587ae3c5939cba4b9b5f3dc5e8e911ef9a", + "5ce1a5a749a4989d0d8368f6e1f8cdf3a362a6c97fb02047ff152b480a4ad985", + "2d45efdf0770542992afca6a0590d52930434bba96017afbc9f99e112950a8b1", + "a359473ec376f329bdae6a19f503be6d4be7393c4e43468831234e27e3838680", + "b949390d2e416a3f9759e5349ab4c253f6f29f819a6fe4cbfd27ada34903300e", + "da021f62839f5878a36f1bc3085375b00fd5fa3e68d316c0fdace87a97558465", + NULL}; + static const char *phex[] = { + "f95dc0f980fbd22e90caa5a387cc4a369f3f830d50dd321c40db8c09a7e1a241", + "a536e096622d3280c0c1ba849c1f4a79bf490f60006d081e8cf69960189f0d31", + "2cd9e17073a3fba7881b21474a13b334116cb2f5dbf3189a6de3515d0840f053", + "c776d3982d391b6d04d642dda5cc6d1640174c09875addb70595658f89efb439", + "dc6fbd55f903aadd307982d3f659207f265e1ec6271b274521b7a5e28e8fd7a5", + "5df089292820477802a43cf5b6b94e999e8c9944ddebb0d0e95a60f88cb7e813", + "ba110d20e1024774107dd02949031864923b3cb8c3f7250d6d1287b0a40db6a4", + "7bd5a469518eb65aa207ddc47d8c6e5fc8e0c105be8fc1d4b57b2e27540471d5", + NULL}; + static const char *mhex[] = { + "fef15d5ce4625f1bccfbba49fc8439c72bf8202af039a2259678941b60bb4a8f", + "2987e965d58fd8cf86a856674d519763d0e1211cc9f8596971050d56d9b35db3", + "785866cfbca17cfdbed6060be3629d894f924a89fdc1efc624f80d41a22f1900", + "9503fcc3824ef62ccb9208430c26f2d8ceb2c63488ec4c07437aa4c96c43dd8b", + "9289ed00a712ff66ee195dc71f5e4ead02172b63c543d69baf495f5fd63ba7bc", + "c633bd309c016e37736da92129d0b053d4ab28d21ad7d8b6fab2a8bbdc8ee647", + "d2fbcf2cf426cf892e6f5639e0252993965dfb73ccd277407014ea784aaa280c", + "b7b03972bc8b0baa72360bdb44b82415b86b2f260f877791cd33ba8f2d65229b", + NULL}; + + if (!TEST_true(parse_bigBN(&e, ehex)) + || !TEST_true(parse_bigBN(&p, phex)) + || !TEST_true(parse_bigBN(&m, mhex)) + || !TEST_true(BN_mod_exp_mont_consttime(d, e, p, m, ctx, NULL)) + || !TEST_true(BN_mod_exp_simple(a, e, p, m, ctx)) + || !TEST_BN_eq(a, d)) + goto err; + } + /* Zero input */ if (!TEST_true(BN_bntest_rand(p, 1024, 0, 0))) goto err; ================================================ FILE: third_party/patch/openssl/CVE-2022-0778.patch ================================================ diff --git a/crypto/bn/bn_sqrt.c b/crypto/bn/bn_sqrt.c index 1723d5ded5..53b0f55985 100644 --- a/crypto/bn/bn_sqrt.c +++ b/crypto/bn/bn_sqrt.c @@ -14,7 +14,8 @@ BIGNUM *BN_mod_sqrt(BIGNUM *in, const BIGNUM *a, const BIGNUM *p, BN_CTX *ctx) /* * Returns 'ret' such that ret^2 == a (mod p), using the Tonelli/Shanks * algorithm (cf. Henri Cohen, "A Course in Algebraic Computational Number - * Theory", algorithm 1.5.1). 'p' must be prime! + * Theory", algorithm 1.5.1). 'p' must be prime, otherwise an error or + * an incorrect "result" will be returned. */ { BIGNUM *ret = in; @@ -301,18 +302,23 @@ BIGNUM *BN_mod_sqrt(BIGNUM *in, const BIGNUM *a, const BIGNUM *p, BN_CTX *ctx) goto vrfy; } - /* find smallest i such that b^(2^i) = 1 */ - i = 1; - if (!BN_mod_sqr(t, b, p, ctx)) - goto end; - while (!BN_is_one(t)) { - i++; - if (i == e) { - BNerr(BN_F_BN_MOD_SQRT, BN_R_NOT_A_SQUARE); - goto end; + /* Find the smallest i, 0 < i < e, such that b^(2^i) = 1. 
*/ + for (i = 1; i < e; i++) { + if (i == 1) { + if (!BN_mod_sqr(t, b, p, ctx)) + goto end; + + } else { + if (!BN_mod_mul(t, t, t, p, ctx)) + goto end; } - if (!BN_mod_mul(t, t, t, p, ctx)) - goto end; + if (BN_is_one(t)) + break; + } + /* If not found, a is not a square or p is not prime. */ + if (i >= e) { + BNerr(BN_F_BN_MOD_SQRT, BN_R_NOT_A_SQUARE); + goto end; } /* t := y^2^(e - i - 1) */ ================================================ FILE: third_party/patch/openssl/CVE-2022-1292.patch ================================================ diff --git a/tools/c_rehash.in b/tools/c_rehash.in index fa7c6c9fef..83c1cc80e0 100644 --- a/tools/c_rehash.in +++ b/tools/c_rehash.in @@ -152,6 +152,23 @@ sub check_file { return ($is_cert, $is_crl); } +sub compute_hash { + my $fh; + if ( $^O eq "VMS" ) { + # VMS uses the open through shell + # The file names are safe there and list form is unsupported + if (!open($fh, "-|", join(' ', @_))) { + print STDERR "Cannot compute hash on '$fname'\n"; + return; + } + } else { + if (!open($fh, "-|", @_)) { + print STDERR "Cannot compute hash on '$fname'\n"; + return; + } + } + return (<$fh>, <$fh>); +} # Link a certificate to its subject name hash value, each hash is of # the form . where n is an integer. If the hash value already exists @@ -161,10 +178,12 @@ sub check_file { sub link_hash_cert { my $fname = $_[0]; - $fname =~ s/\"/\\\"/g; - my ($hash, $fprint) = `"$openssl" x509 $x509hash -fingerprint -noout -in "$fname"`; + my ($hash, $fprint) = compute_hash($openssl, "x509", $x509hash, + "-fingerprint", "-noout", + "-in", $fname); chomp $hash; chomp $fprint; + return if !$hash; $fprint =~ s/^.*=//; $fprint =~ tr/://d; my $suffix = 0; @@ -202,10 +221,12 @@ sub link_hash_cert { sub link_hash_crl { my $fname = $_[0]; - $fname =~ s/'/'\\''/g; - my ($hash, $fprint) = `"$openssl" crl $crlhash -fingerprint -noout -in '$fname'`; + my ($hash, $fprint) = compute_hash($openssl, "crl", $crlhash, + "-fingerprint", "-noout", + "-in", $fname); chomp $hash; chomp $fprint; + return if !$hash; $fprint =~ s/^.*=//; $fprint =~ tr/://d; my $suffix = 0; ================================================ FILE: third_party/patch/openssl/CVE-2022-2068.patch ================================================ diff --git a/tools/c_rehash.in b/tools/c_rehash.in index cfd18f5da1..9d2a6f6db7 100644 --- a/tools/c_rehash.in +++ b/tools/c_rehash.in @@ -104,52 +104,78 @@ foreach (@dirlist) { } exit($errorcount); +sub copy_file { + my ($src_fname, $dst_fname) = @_; + + if (open(my $in, "<", $src_fname)) { + if (open(my $out, ">", $dst_fname)) { + print $out $_ while (<$in>); + close $out; + } else { + warn "Cannot open $dst_fname for write, $!"; + } + close $in; + } else { + warn "Cannot open $src_fname for read, $!"; + } +} + sub hash_dir { - my %hashlist; - print "Doing $_[0]\n"; - chdir $_[0]; - opendir(DIR, "."); - my @flist = sort readdir(DIR); - closedir DIR; - if ( $removelinks ) { - # Delete any existing symbolic links - foreach (grep {/^[\da-f]+\.r{0,1}\d+$/} @flist) { - if (-l $_) { - print "unlink $_" if $verbose; - unlink $_ || warn "Can't unlink $_, $!\n"; - } - } - } - FILE: foreach $fname (grep {/\.(pem)|(crt)|(cer)|(crl)$/} @flist) { - # Check to see if certificates and/or CRLs present. 
- my ($cert, $crl) = check_file($fname); - if (!$cert && !$crl) { - print STDERR "WARNING: $fname does not contain a certificate or CRL: skipping\n"; - next; - } - link_hash_cert($fname) if ($cert); - link_hash_crl($fname) if ($crl); - } + my $dir = shift; + my %hashlist; + + print "Doing $dir\n"; + + if (!chdir $dir) { + print STDERR "WARNING: Cannot chdir to '$dir', $!\n"; + return; + } + + opendir(DIR, ".") || print STDERR "WARNING: Cannot opendir '.', $!\n"; + my @flist = sort readdir(DIR); + closedir DIR; + if ( $removelinks ) { + # Delete any existing symbolic links + foreach (grep {/^[\da-f]+\.r{0,1}\d+$/} @flist) { + if (-l $_) { + print "unlink $_\n" if $verbose; + unlink $_ || warn "Can't unlink $_, $!\n"; + } + } + } + FILE: foreach $fname (grep {/\.(pem)|(crt)|(cer)|(crl)$/} @flist) { + # Check to see if certificates and/or CRLs present. + my ($cert, $crl) = check_file($fname); + if (!$cert && !$crl) { + print STDERR "WARNING: $fname does not contain a certificate or CRL: skipping\n"; + next; + } + link_hash_cert($fname) if ($cert); + link_hash_crl($fname) if ($crl); + } + + chdir $pwd; } sub check_file { - my ($is_cert, $is_crl) = (0,0); - my $fname = $_[0]; - open IN, $fname; - while() { - if (/^-----BEGIN (.*)-----/) { - my $hdr = $1; - if ($hdr =~ /^(X509 |TRUSTED |)CERTIFICATE$/) { - $is_cert = 1; - last if ($is_crl); - } elsif ($hdr eq "X509 CRL") { - $is_crl = 1; - last if ($is_cert); - } - } - } - close IN; - return ($is_cert, $is_crl); + my ($is_cert, $is_crl) = (0,0); + my $fname = $_[0]; + + open(my $in, "<", $fname); + while(<$in>) { + if (/^-----BEGIN (.*)-----/) { + my $hdr = $1; + if ($hdr =~ /^(X509 |TRUSTED |)CERTIFICATE$/) { + $is_cert = 1; + last if ($is_crl); + } elsif ($hdr eq "X509 CRL") { + $is_crl = 1; + last if ($is_cert); + } + } + } + close $in; + return ($is_cert, $is_crl); } sub compute_hash { @@ -177,76 +203,48 @@ sub compute_hash { # certificate fingerprints sub link_hash_cert { - my $fname = $_[0]; - my ($hash, $fprint) = compute_hash($openssl, "x509", $x509hash, - "-fingerprint", "-noout", - "-in", $fname); - chomp $hash; - chomp $fprint; - return if !$hash; - $fprint =~ s/^.*=//; - $fprint =~ tr/://d; - my $suffix = 0; - # Search for an unused hash filename - while(exists $hashlist{"$hash.$suffix"}) { - # Hash matches: if fingerprint matches its a duplicate cert - if ($hashlist{"$hash.$suffix"} eq $fprint) { - print STDERR "WARNING: Skipping duplicate certificate $fname\n"; - return; - } - $suffix++; - } - $hash .= ".$suffix"; - if ($symlink_exists) { - print "link $fname -> $hash\n" if $verbose; - symlink $fname, $hash || warn "Can't symlink, $!"; - } else { - print "copy $fname -> $hash\n" if $verbose; - if (open($in, "<", $fname)) { - if (open($out,">", $hash)) { - print $out $_ while (<$in>); - close $out; - } else { - warn "can't open $hash for write, $!"; - } - close $in; - } else { - warn "can't open $fname for read, $!"; - } - } - $hashlist{$hash} = $fprint; + link_hash($_[0], 'cert'); } # Same as above except for a CRL. 
CRL links are of the form .r sub link_hash_crl { - my $fname = $_[0]; - my ($hash, $fprint) = compute_hash($openssl, "crl", $crlhash, - "-fingerprint", "-noout", - "-in", $fname); - chomp $hash; - chomp $fprint; - return if !$hash; - $fprint =~ s/^.*=//; - $fprint =~ tr/://d; - my $suffix = 0; - # Search for an unused hash filename - while(exists $hashlist{"$hash.r$suffix"}) { - # Hash matches: if fingerprint matches its a duplicate cert - if ($hashlist{"$hash.r$suffix"} eq $fprint) { - print STDERR "WARNING: Skipping duplicate CRL $fname\n"; - return; - } - $suffix++; - } - $hash .= ".r$suffix"; - if ($symlink_exists) { - print "link $fname -> $hash\n" if $verbose; - symlink $fname, $hash || warn "Can't symlink, $!"; - } else { - print "cp $fname -> $hash\n" if $verbose; - system ("cp", $fname, $hash); - warn "Can't copy, $!" if ($? >> 8) != 0; - } - $hashlist{$hash} = $fprint; + link_hash($_[0], 'crl'); +} + +sub link_hash { + my ($fname, $type) = @_; + my $is_cert = $type eq 'cert'; + + my ($hash, $fprint) = compute_hash($openssl, + $is_cert ? "x509" : "crl", + $is_cert ? $x509hash : $crlhash, + "-fingerprint", "-noout", + "-in", $fname); + chomp $hash; + chomp $fprint; + return if !$hash; + $fprint =~ s/^.*=//; + $fprint =~ tr/://d; + my $suffix = 0; + # Search for an unused hash filename + my $crlmark = $is_cert ? "" : "r"; + while(exists $hashlist{"$hash.$crlmark$suffix"}) { + # Hash matches: if fingerprint matches its a duplicate cert + if ($hashlist{"$hash.$crlmark$suffix"} eq $fprint) { + my $what = $is_cert ? 'certificate' : 'CRL'; + print STDERR "WARNING: Skipping duplicate $what $fname\n"; + return; + } + $suffix++; + } + $hash .= ".$crlmark$suffix"; + if ($symlink_exists) { + print "link $fname -> $hash\n" if $verbose; + symlink $fname, $hash || warn "Can't symlink, $!"; + } else { + print "copy $fname -> $hash\n" if $verbose; + copy_file($fname, $hash); + } + $hashlist{$hash} = $fprint; } ================================================ FILE: third_party/patch/openssl/CVE-2022-2097.patch ================================================ diff --git a/crypto/aes/asm/aesni-x86.pl b/crypto/aes/asm/aesni-x86.pl index fe2b26542a..812758e02e 100644 --- a/crypto/aes/asm/aesni-x86.pl +++ b/crypto/aes/asm/aesni-x86.pl @@ -2027,7 +2027,7 @@ my ($l_,$block,$i1,$i3,$i5) = ($rounds_,$key_,$rounds,$len,$out); &movdqu (&QWP(-16*2,$out,$inp),$inout4); &movdqu (&QWP(-16*1,$out,$inp),$inout5); &cmp ($inp,$len); # done yet? - &jb (&label("grandloop")); + &jbe (&label("grandloop")); &set_label("short"); &add ($len,16*6); @@ -2453,7 +2453,7 @@ my ($l_,$block,$i1,$i3,$i5) = ($rounds_,$key_,$rounds,$len,$out); &pxor ($rndkey1,$inout5); &movdqu (&QWP(-16*1,$out,$inp),$inout5); &cmp ($inp,$len); # done yet? 
- &jb (&label("grandloop")); + &jbe (&label("grandloop")); &set_label("short"); &add ($len,16*6); ================================================ FILE: third_party/patch/openssl/CVE-2022-4304.patch ================================================ diff --git a/crypto/bn/bn_blind.c b/crypto/bn/bn_blind.c index 76fc7ebcff..6e9d239321 100644 --- a/crypto/bn/bn_blind.c +++ b/crypto/bn/bn_blind.c @@ -13,20 +13,6 @@ #define BN_BLINDING_COUNTER 32 -struct bn_blinding_st { - BIGNUM *A; - BIGNUM *Ai; - BIGNUM *e; - BIGNUM *mod; /* just a reference */ - CRYPTO_THREAD_ID tid; - int counter; - unsigned long flags; - BN_MONT_CTX *m_ctx; - int (*bn_mod_exp) (BIGNUM *r, const BIGNUM *a, const BIGNUM *p, - const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *m_ctx); - CRYPTO_RWLOCK *lock; -}; - BN_BLINDING *BN_BLINDING_new(const BIGNUM *A, const BIGNUM *Ai, BIGNUM *mod) { BN_BLINDING *ret = NULL; diff --git a/crypto/bn/bn_err.c b/crypto/bn/bn_err.c index dd87c152cf..3dd8d9a568 100644 --- a/crypto/bn/bn_err.c +++ b/crypto/bn/bn_err.c @@ -73,6 +73,8 @@ static const ERR_STRING_DATA BN_str_functs[] = { {ERR_PACK(ERR_LIB_BN, BN_F_BN_SET_WORDS, 0), "bn_set_words"}, {ERR_PACK(ERR_LIB_BN, BN_F_BN_STACK_PUSH, 0), "BN_STACK_push"}, {ERR_PACK(ERR_LIB_BN, BN_F_BN_USUB, 0), "BN_usub"}, + {ERR_PACK(ERR_LIB_BN, BN_F_OSSL_BN_RSA_DO_UNBLIND, 0), + "ossl_bn_rsa_do_unblind"}, {0, NULL} }; diff --git a/crypto/bn/bn_local.h b/crypto/bn/bn_local.h index 62a969b134..4d8cb64675 100644 --- a/crypto/bn/bn_local.h +++ b/crypto/bn/bn_local.h @@ -283,6 +283,20 @@ struct bn_gencb_st { } cb; }; +struct bn_blinding_st { + BIGNUM *A; + BIGNUM *Ai; + BIGNUM *e; + BIGNUM *mod; /* just a reference */ + CRYPTO_THREAD_ID tid; + int counter; + unsigned long flags; + BN_MONT_CTX *m_ctx; + int (*bn_mod_exp) (BIGNUM *r, const BIGNUM *a, const BIGNUM *p, + const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *m_ctx); + CRYPTO_RWLOCK *lock; +}; + /*- * BN_window_bits_for_exponent_size -- macro for sliding window mod_exp functions * diff --git a/crypto/bn/build.info b/crypto/bn/build.info index b9ed5322fa..c9fe2fdada 100644 --- a/crypto/bn/build.info +++ b/crypto/bn/build.info @@ -5,7 +5,8 @@ SOURCE[../../libcrypto]=\ bn_kron.c bn_sqrt.c bn_gcd.c bn_prime.c bn_err.c bn_sqr.c \ {- $target{bn_asm_src} -} \ bn_recp.c bn_mont.c bn_mpi.c bn_exp2.c bn_gf2m.c bn_nist.c \ - bn_depr.c bn_const.c bn_x931p.c bn_intern.c bn_dh.c bn_srp.c + bn_depr.c bn_const.c bn_x931p.c bn_intern.c bn_dh.c bn_srp.c \ + rsa_sup_mul.c INCLUDE[bn_exp.o]=.. 
diff --git a/crypto/bn/rsa_sup_mul.c b/crypto/bn/rsa_sup_mul.c new file mode 100644 index 0000000000..acafefd5fe --- /dev/null +++ b/crypto/bn/rsa_sup_mul.c @@ -0,0 +1,614 @@ +#include +#include +#include +#include +#include +#include +#include +#include "internal/numbers.h" +#include "internal/constant_time.h" +#include "bn_local.h" + +# if BN_BYTES == 8 +typedef uint64_t limb_t; +# if defined(__SIZEOF_INT128__) && __SIZEOF_INT128__ == 16 +/* nonstandard; implemented by gcc on 64-bit platforms */ +typedef __uint128_t limb2_t; +# define HAVE_LIMB2_T +# endif +# define LIMB_BIT_SIZE 64 +# define LIMB_BYTE_SIZE 8 +# elif BN_BYTES == 4 +typedef uint32_t limb_t; +typedef uint64_t limb2_t; +# define LIMB_BIT_SIZE 32 +# define LIMB_BYTE_SIZE 4 +# define HAVE_LIMB2_T +# else +# error "Not supported" +# endif + +/* + * For multiplication we're using schoolbook multiplication, + * so if we have two numbers, each with 6 "digits" (words) + * the multiplication is calculated as follows: + * A B C D E F + * x I J K L M N + * -------------- + * N*F + * N*E + * N*D + * N*C + * N*B + * N*A + * M*F + * M*E + * M*D + * M*C + * M*B + * M*A + * L*F + * L*E + * L*D + * L*C + * L*B + * L*A + * K*F + * K*E + * K*D + * K*C + * K*B + * K*A + * J*F + * J*E + * J*D + * J*C + * J*B + * J*A + * I*F + * I*E + * I*D + * I*C + * I*B + * + I*A + * ========================== + * N*B N*D N*F + * + N*A N*C N*E + * + M*B M*D M*F + * + M*A M*C M*E + * + L*B L*D L*F + * + L*A L*C L*E + * + K*B K*D K*F + * + K*A K*C K*E + * + J*B J*D J*F + * + J*A J*C J*E + * + I*B I*D I*F + * + I*A I*C I*E + * + * 1+1 1+3 1+5 + * 1+0 1+2 1+4 + * 0+1 0+3 0+5 + * 0+0 0+2 0+4 + * + * 0 1 2 3 4 5 6 + * which requires n^2 multiplications and 2n full length additions + * as we can keep every other result of limb multiplication in two separate + * limbs + */ + +#if defined HAVE_LIMB2_T +static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) +{ + limb2_t t; + /* + * this is idiomatic code to tell compiler to use the native mul + * those three lines will actually compile to single instruction + */ + + t = (limb2_t)a * b; + *hi = t >> LIMB_BIT_SIZE; + *lo = (limb_t)t; +} +#elif (BN_BYTES == 8) && (defined _MSC_VER) +/* https://learn.microsoft.com/en-us/cpp/intrinsics/umul128?view=msvc-170 */ +#pragma intrinsic(_umul128) +static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) +{ + *lo = _umul128(a, b, hi); +} +#else +/* + * if the compiler doesn't have either a 128bit data type nor a "return + * high 64 bits of multiplication" + */ +static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) +{ + limb_t a_low = (limb_t)(uint32_t)a; + limb_t a_hi = a >> 32; + limb_t b_low = (limb_t)(uint32_t)b; + limb_t b_hi = b >> 32; + + limb_t p0 = a_low * b_low; + limb_t p1 = a_low * b_hi; + limb_t p2 = a_hi * b_low; + limb_t p3 = a_hi * b_hi; + + uint32_t cy = (uint32_t)(((p0 >> 32) + (uint32_t)p1 + (uint32_t)p2) >> 32); + + *lo = p0 + (p1 << 32) + (p2 << 32); + *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy; +} +#endif + +/* add two limbs with carry in, return carry out */ +static ossl_inline limb_t _add_limb(limb_t *ret, limb_t a, limb_t b, limb_t carry) +{ + limb_t carry1, carry2, t; + /* + * `c = a + b; if (c < a)` is idiomatic code that makes compilers + * use add with carry on assembly level + */ + + *ret = a + carry; + if (*ret < a) + carry1 = 1; + else + carry1 = 0; + + t = *ret; + *ret = t + b; + if (*ret < t) + carry2 = 1; + else + carry2 = 0; + + return carry1 + carry2; +} + +/* + * add two numbers of 
the same size, return overflow + * + * add a to b, place result in ret; all arrays need to be n limbs long + * return overflow from addition (0 or 1) + */ +static ossl_inline limb_t add(limb_t *ret, limb_t *a, limb_t *b, size_t n) +{ + limb_t c = 0; + ossl_ssize_t i; + + for(i = n - 1; i > -1; i--) + c = _add_limb(&ret[i], a[i], b[i], c); + + return c; +} + +/* + * return number of limbs necessary for temporary values + * when multiplying numbers n limbs large + */ +static ossl_inline size_t mul_limb_numb(size_t n) +{ + return 2 * n * 2; +} + +/* + * multiply two numbers of the same size + * + * multiply a by b, place result in ret; a and b need to be n limbs long + * ret needs to be 2*n limbs long, tmp needs to be mul_limb_numb(n) limbs + * long + */ +static void limb_mul(limb_t *ret, limb_t *a, limb_t *b, size_t n, limb_t *tmp) +{ + limb_t *r_odd, *r_even; + size_t i, j, k; + + r_odd = tmp; + r_even = &tmp[2 * n]; + + memset(ret, 0, 2 * n * sizeof(limb_t)); + + for (i = 0; i < n; i++) { + for (k = 0; k < i + n + 1; k++) { + r_even[k] = 0; + r_odd[k] = 0; + } + for (j = 0; j < n; j++) { + /* + * place results from even and odd limbs in separate arrays so that + * we don't have to calculate overflow every time we get individual + * limb multiplication result + */ + if (j % 2 == 0) + _mul_limb(&r_even[i + j], &r_even[i + j + 1], a[i], b[j]); + else + _mul_limb(&r_odd[i + j], &r_odd[i + j + 1], a[i], b[j]); + } + /* + * skip the least significant limbs when adding multiples of + * more significant limbs (they're zero anyway) + */ + add(ret, ret, r_even, n + i + 1); + add(ret, ret, r_odd, n + i + 1); + } +} + +/* modifies the value in place by performing a right shift by one bit */ +static ossl_inline void rshift1(limb_t *val, size_t n) +{ + limb_t shift_in = 0, shift_out = 0; + size_t i; + + for (i = 0; i < n; i++) { + shift_out = val[i] & 1; + val[i] = shift_in << (LIMB_BIT_SIZE - 1) | (val[i] >> 1); + shift_in = shift_out; + } +} + +/* extend the LSB of flag to all bits of limb */ +static ossl_inline limb_t mk_mask(limb_t flag) +{ + flag |= flag << 1; + flag |= flag << 2; + flag |= flag << 4; + flag |= flag << 8; + flag |= flag << 16; +#if (LIMB_BYTE_SIZE == 8) + flag |= flag << 32; +#endif + return flag; +} + +/* + * copy from either a or b to ret based on flag + * when flag == 0, then copies from b + * when flag == 1, then copies from a + */ +static ossl_inline void cselect(limb_t flag, limb_t *ret, limb_t *a, limb_t *b, size_t n) +{ + /* + * would be more efficient with non volatile mask, but then gcc + * generates code with jumps + */ + volatile limb_t mask; + size_t i; + + mask = mk_mask(flag); + for (i = 0; i < n; i++) { +#if (LIMB_BYTE_SIZE == 8) + ret[i] = constant_time_select_64(mask, a[i], b[i]); +#else + ret[i] = constant_time_select_32(mask, a[i], b[i]); +#endif + } +} + +static limb_t _sub_limb(limb_t *ret, limb_t a, limb_t b, limb_t borrow) +{ + limb_t borrow1, borrow2, t; + /* + * while it doesn't look constant-time, this is idiomatic code + * to tell compilers to use the carry bit from subtraction + */ + + *ret = a - borrow; + if (*ret > a) + borrow1 = 1; + else + borrow1 = 0; + + t = *ret; + *ret = t - b; + if (*ret > t) + borrow2 = 1; + else + borrow2 = 0; + + return borrow1 + borrow2; +} + +/* + * place the result of a - b into ret, return the borrow bit. 
+ * All arrays need to be n limbs long + */ +static limb_t sub(limb_t *ret, limb_t *a, limb_t *b, size_t n) +{ + limb_t borrow = 0; + ossl_ssize_t i; + + for (i = n - 1; i > -1; i--) + borrow = _sub_limb(&ret[i], a[i], b[i], borrow); + + return borrow; +} + +/* return the number of limbs necessary to allocate for the mod() tmp operand */ +static ossl_inline size_t mod_limb_numb(size_t anum, size_t modnum) +{ + return (anum + modnum) * 3; +} + +/* + * calculate a % mod, place the result in ret + * size of a is defined by anum, size of ret and mod is modnum, + * size of tmp is returned by mod_limb_numb() + */ +static void mod(limb_t *ret, limb_t *a, size_t anum, limb_t *mod, + size_t modnum, limb_t *tmp) +{ + limb_t *atmp, *modtmp, *rettmp; + limb_t res; + size_t i; + + memset(tmp, 0, mod_limb_numb(anum, modnum) * LIMB_BYTE_SIZE); + + atmp = tmp; + modtmp = &tmp[anum + modnum]; + rettmp = &tmp[(anum + modnum) * 2]; + + for (i = modnum; i 0; i--, rp--) { + v = _mul_add_limb(rp, mod, modnum, rp[modnum - 1] * ni0, tmp2); + v = v + carry + rp[-1]; + carry |= (v != rp[-1]); + carry &= (v <= rp[-1]); + rp[-1] = v; + } + + /* perform the final reduction by mod... */ + carry -= sub(ret, rp, mod, modnum); + + /* ...conditionally */ + cselect(carry, ret, rp, ret, modnum); +} + +/* allocated buffer should be freed afterwards */ +static void BN_to_limb(const BIGNUM *bn, limb_t *buf, size_t limbs) +{ + int i; + int real_limbs = (BN_num_bytes(bn) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; + limb_t *ptr = buf + (limbs - real_limbs); + + for (i = 0; i < real_limbs; i++) + ptr[i] = bn->d[real_limbs - i - 1]; +} + +#if LIMB_BYTE_SIZE == 8 +static ossl_inline uint64_t be64(uint64_t host) +{ + const union { + long one; + char little; + } is_endian = { 1 }; + + if (is_endian.little) { + uint64_t big = 0; + + big |= (host & 0xff00000000000000) >> 56; + big |= (host & 0x00ff000000000000) >> 40; + big |= (host & 0x0000ff0000000000) >> 24; + big |= (host & 0x000000ff00000000) >> 8; + big |= (host & 0x00000000ff000000) << 8; + big |= (host & 0x0000000000ff0000) << 24; + big |= (host & 0x000000000000ff00) << 40; + big |= (host & 0x00000000000000ff) << 56; + return big; + } else { + return host; + } +} + +#else +/* Not all platforms have htobe32(). */ +static ossl_inline uint32_t be32(uint32_t host) +{ + const union { + long one; + char little; + } is_endian = { 1 }; + + if (is_endian.little) { + uint32_t big = 0; + + big |= (host & 0xff000000) >> 24; + big |= (host & 0x00ff0000) >> 8; + big |= (host & 0x0000ff00) << 8; + big |= (host & 0x000000ff) << 24; + return big; + } else { + return host; + } +} +#endif + +/* + * We assume that intermediate, possible_arg2, blinding, and ctx are used + * similar to BN_BLINDING_invert_ex() arguments. + * to_mod is RSA modulus. + * buf and num is the serialization buffer and its length. + * + * Here we use classic/Montgomery multiplication and modulo. After the calculation finished + * we serialize the new structure instead of BIGNUMs taking endianness into account. + */ +int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate, + const BN_BLINDING *blinding, + const BIGNUM *possible_arg2, + const BIGNUM *to_mod, BN_CTX *ctx, + unsigned char *buf, int num) +{ + limb_t *l_im = NULL, *l_mul = NULL, *l_mod = NULL; + limb_t *l_ret = NULL, *l_tmp = NULL, l_buf; + size_t l_im_count = 0, l_mul_count = 0, l_size = 0, l_mod_count = 0; + size_t l_tmp_count = 0; + int ret = 0; + size_t i; + unsigned char *tmp; + const BIGNUM *arg1 = intermediate; + const BIGNUM *arg2 = (possible_arg2 == NULL) ? 
blinding->Ai : possible_arg2; + + l_im_count = (BN_num_bytes(arg1) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; + l_mul_count = (BN_num_bytes(arg2) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; + l_mod_count = (BN_num_bytes(to_mod) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; + + l_size = l_im_count > l_mul_count ? l_im_count : l_mul_count; + l_im = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE); + l_mul = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE); + l_mod = OPENSSL_zalloc(l_mod_count * LIMB_BYTE_SIZE); + + if ((l_im == NULL) || (l_mul == NULL) || (l_mod == NULL)) + goto err; + + BN_to_limb(arg1, l_im, l_size); + BN_to_limb(arg2, l_mul, l_size); + BN_to_limb(to_mod, l_mod, l_mod_count); + + l_ret = OPENSSL_malloc(2 * l_size * LIMB_BYTE_SIZE); + + if (blinding->m_ctx != NULL) { + l_tmp_count = mul_limb_numb(l_size) > mod_montgomery_limb_numb(l_mod_count) ? + mul_limb_numb(l_size) : mod_montgomery_limb_numb(l_mod_count); + l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE); + } else { + l_tmp_count = mul_limb_numb(l_size) > mod_limb_numb(2 * l_size, l_mod_count) ? + mul_limb_numb(l_size) : mod_limb_numb(2 * l_size, l_mod_count); + l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE); + } + + if ((l_ret == NULL) || (l_tmp == NULL)) + goto err; + + if (blinding->m_ctx != NULL) { + limb_mul(l_ret, l_im, l_mul, l_size, l_tmp); + mod_montgomery(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, + blinding->m_ctx->n0[0], l_tmp); + } else { + limb_mul(l_ret, l_im, l_mul, l_size, l_tmp); + mod(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, l_tmp); + } + + /* modulus size in bytes can be equal to num but after limbs conversion it becomes bigger */ + if (num < BN_num_bytes(to_mod)) { + BNerr(BN_F_OSSL_BN_RSA_DO_UNBLIND, ERR_R_PASSED_INVALID_ARGUMENT); + goto err; + } + + memset(buf, 0, num); + tmp = buf + num - BN_num_bytes(to_mod); + for (i = 0; i < l_mod_count; i++) { +#if LIMB_BYTE_SIZE == 8 + l_buf = be64(l_ret[i]); +#else + l_buf = be32(l_ret[i]); +#endif + if (i == 0) { + int delta = LIMB_BYTE_SIZE - ((l_mod_count * LIMB_BYTE_SIZE) - num); + + memcpy(tmp, ((char *)&l_buf) + LIMB_BYTE_SIZE - delta, delta); + tmp += delta; + } else { + memcpy(tmp, &l_buf, LIMB_BYTE_SIZE); + tmp += LIMB_BYTE_SIZE; + } + } + ret = num; + + err: + OPENSSL_free(l_im); + OPENSSL_free(l_mul); + OPENSSL_free(l_mod); + OPENSSL_free(l_tmp); + OPENSSL_free(l_ret); + + return ret; +} diff --git a/crypto/err/openssl.txt b/crypto/err/openssl.txt index 9f91a4a811..ba3a46d5b9 100644 --- a/crypto/err/openssl.txt +++ b/crypto/err/openssl.txt @@ -1,4 +1,4 @@ -# Copyright 1999-2021 The OpenSSL Project Authors. All Rights Reserved. +# Copyright 1999-2023 The OpenSSL Project Authors. All Rights Reserved. # # Licensed under the OpenSSL license (the "License"). You may not use # this file except in compliance with the License. 
You can obtain a copy @@ -232,6 +232,7 @@ BN_F_BN_RSHIFT:146:BN_rshift BN_F_BN_SET_WORDS:144:bn_set_words BN_F_BN_STACK_PUSH:148:BN_STACK_push BN_F_BN_USUB:115:BN_usub +BN_F_OSSL_BN_RSA_DO_UNBLIND:151:ossl_bn_rsa_do_unblind BUF_F_BUF_MEM_GROW:100:BUF_MEM_grow BUF_F_BUF_MEM_GROW_CLEAN:105:BUF_MEM_grow_clean BUF_F_BUF_MEM_NEW:101:BUF_MEM_new diff --git a/crypto/rsa/rsa_ossl.c b/crypto/rsa/rsa_ossl.c index b52a66f6a6..6c3c0cf78d 100644 --- a/crypto/rsa/rsa_ossl.c +++ b/crypto/rsa/rsa_ossl.c @@ -465,11 +465,20 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from, BN_free(d); } - if (blinding) - if (!rsa_blinding_invert(blinding, ret, unblind, ctx)) + if (blinding) { + /* + * ossl_bn_rsa_do_unblind() combines blinding inversion and + * 0-padded BN BE serialization + */ + j = ossl_bn_rsa_do_unblind(ret, blinding, unblind, rsa->n, ctx, + buf, num); + if (j == 0) goto err; - - j = BN_bn2binpad(ret, buf, num); + } else { + j = BN_bn2binpad(ret, buf, num); + if (j < 0) + goto err; + } switch (padding) { case RSA_PKCS1_PADDING: diff --git a/include/crypto/bn.h b/include/crypto/bn.h index 60afda1dad..b5f36fb25a 100644 --- a/include/crypto/bn.h +++ b/include/crypto/bn.h @@ -86,5 +86,10 @@ int bn_lshift_fixed_top(BIGNUM *r, const BIGNUM *a, int n); int bn_rshift_fixed_top(BIGNUM *r, const BIGNUM *a, int n); int bn_div_fixed_top(BIGNUM *dv, BIGNUM *rem, const BIGNUM *m, const BIGNUM *d, BN_CTX *ctx); +int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate, + const BN_BLINDING *blinding, + const BIGNUM *possible_arg2, + const BIGNUM *to_mod, BN_CTX *ctx, + unsigned char *buf, int num); #endif diff --git a/include/openssl/bnerr.h b/include/openssl/bnerr.h index 9f3c7cfaab..a0752cea52 100644 --- a/include/openssl/bnerr.h +++ b/include/openssl/bnerr.h @@ -72,6 +72,7 @@ int ERR_load_BN_strings(void); # define BN_F_BN_SET_WORDS 144 # define BN_F_BN_STACK_PUSH 148 # define BN_F_BN_USUB 115 +# define BN_F_OSSL_BN_RSA_DO_UNBLIND 151 /* * BN reason codes. ================================================ FILE: third_party/patch/openssl/CVE-2022-4450.patch ================================================ diff --git a/crypto/pem/pem_lib.c b/crypto/pem/pem_lib.c index d416d939ea..328c30cdbb 100644 --- a/crypto/pem/pem_lib.c +++ b/crypto/pem/pem_lib.c @@ -957,7 +957,9 @@ int PEM_read_bio_ex(BIO *bp, char **name_out, char **header, *data = pem_malloc(len, flags); if (*header == NULL || *data == NULL) { pem_free(*header, flags, 0); + *header = NULL; pem_free(*data, flags, 0); + *data = NULL; goto end; } BIO_read(headerB, *header, headerlen); ================================================ FILE: third_party/patch/openssl/CVE-2023-0215.patch ================================================ diff --git a/crypto/asn1/bio_ndef.c b/crypto/asn1/bio_ndef.c index 760e4846a4..f8d4b1b9aa 100644 --- a/crypto/asn1/bio_ndef.c +++ b/crypto/asn1/bio_ndef.c @@ -49,12 +49,19 @@ static int ndef_suffix(BIO *b, unsigned char **pbuf, int *plen, void *parg); static int ndef_suffix_free(BIO *b, unsigned char **pbuf, int *plen, void *parg); +/* + * On success, the returned BIO owns the input BIO as part of its BIO chain. + * On failure, NULL is returned and the input BIO is owned by the caller. 
+ * + * Unfortunately cannot constify this due to CMS_stream() and PKCS7_stream() + */ BIO *BIO_new_NDEF(BIO *out, ASN1_VALUE *val, const ASN1_ITEM *it) { NDEF_SUPPORT *ndef_aux = NULL; BIO *asn_bio = NULL; const ASN1_AUX *aux = it->funcs; ASN1_STREAM_ARG sarg; + BIO *pop_bio = NULL; if (!aux || !aux->asn1_cb) { ASN1err(ASN1_F_BIO_NEW_NDEF, ASN1_R_STREAMING_NOT_SUPPORTED); @@ -69,21 +76,39 @@ BIO *BIO_new_NDEF(BIO *out, ASN1_VALUE *val, const ASN1_ITEM *it) out = BIO_push(asn_bio, out); if (out == NULL) goto err; + pop_bio = asn_bio; - BIO_asn1_set_prefix(asn_bio, ndef_prefix, ndef_prefix_free); - BIO_asn1_set_suffix(asn_bio, ndef_suffix, ndef_suffix_free); + if (BIO_asn1_set_prefix(asn_bio, ndef_prefix, ndef_prefix_free) <= 0 + || BIO_asn1_set_suffix(asn_bio, ndef_suffix, ndef_suffix_free) <= 0 + || BIO_ctrl(asn_bio, BIO_C_SET_EX_ARG, 0, ndef_aux) <= 0) + goto err; /* - * Now let callback prepends any digest, cipher etc BIOs ASN1 structure - * needs. + * Now let the callback prepend any digest, cipher, etc., that the BIO's + * ASN1 structure needs. */ sarg.out = out; sarg.ndef_bio = NULL; sarg.boundary = NULL; - if (aux->asn1_cb(ASN1_OP_STREAM_PRE, &val, it, &sarg) <= 0) + /* + * The asn1_cb(), must not have mutated asn_bio on error, leaving it in the + * middle of some partially built, but not returned BIO chain. + */ + if (aux->asn1_cb(ASN1_OP_STREAM_PRE, &val, it, &sarg) <= 0) { + /* + * ndef_aux is now owned by asn_bio so we must not free it in the err + * clean up block + */ + ndef_aux = NULL; goto err; + } + + /* + * We must not fail now because the callback has prepended additional + * BIOs to the chain + */ ndef_aux->val = val; ndef_aux->it = it; @@ -91,11 +116,11 @@ BIO *BIO_new_NDEF(BIO *out, ASN1_VALUE *val, const ASN1_ITEM *it) ndef_aux->boundary = sarg.boundary; ndef_aux->out = out; - BIO_ctrl(asn_bio, BIO_C_SET_EX_ARG, 0, ndef_aux); - return sarg.ndef_bio; err: + /* BIO_pop() is NULL safe */ + (void)BIO_pop(pop_bio); BIO_free(asn_bio); OPENSSL_free(ndef_aux); return NULL; diff --git a/test/recipes/80-test_cms.t b/test/recipes/80-test_cms.t index 5dc6a3aebe..ec11bfc253 100644 --- a/test/recipes/80-test_cms.t +++ b/test/recipes/80-test_cms.t @@ -13,7 +13,7 @@ use warnings; use POSIX; use File::Spec::Functions qw/catfile/; use File::Compare qw/compare_text/; -use OpenSSL::Test qw/:DEFAULT srctop_dir srctop_file/; +use OpenSSL::Test qw/:DEFAULT srctop_dir srctop_file with/; use OpenSSL::Test::Utils; setup("test_cms"); @@ -27,7 +27,7 @@ my $smcont = srctop_file("test", "smcont.txt"); my ($no_des, $no_dh, $no_dsa, $no_ec, $no_ec2m, $no_rc2, $no_zlib) = disabled qw/des dh dsa ec ec2m rc2 zlib/; -plan tests => 6; +plan tests => 7; my @smime_pkcs7_tests = ( @@ -584,3 +584,14 @@ sub check_availability { return ""; } + +# Check that we get the expected failure return code +with({ exit_checker => sub { return shift == 6; } }, + sub { + ok(run(app(['openssl', 'cms', '-encrypt', + '-in', srctop_file("test", "smcont.txt"), + '-stream', '-recip', + srctop_file("test/smime-certs", "badrsa.pem"), + ])), + "Check failure during BIO setup with -stream is handled correctly"); + }); diff --git a/test/smime-certs/badrsa.pem b/test/smime-certs/badrsa.pem new file mode 100644 index 0000000000..f824fc2267 --- /dev/null +++ b/test/smime-certs/badrsa.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIDbTCCAlWgAwIBAgIToTV4Z0iuK08vZP20oTh//hC8BDANBgkqhkiG9w0BAQ0FADAtMSswKQYD +VfcDEyJTYW1wbGUgTEFNUFMgQ2VydGlmaWNhdGUgQXV0aG9yaXR5MCAXDTE5MTEyMDA2NTQxOFoY 
+DzIwNTIwOTI3MDY1NDE4WjAZMRcwFQYDVQQDEw5BbGljZSBMb3ZlbGFjZTCCASIwDQYJKoZIhvcN +AQEBBQADggEPADCCAQoCggEBALT0iehYOBY+TZp/T5K2KNI05Hwr+E3wP6XTvyi6WWyTgBK9LCOw +I2juwdRrjFBmXkk7pWpjXwsA3A5GOtz0FpfgyC7OxsVcF7q4WHWZWleYXFKlQHJD73nQwXP968+A +/3rBX7PhO0DBbZnfitOLPgPEwjTtdg0VQQ6Wz+CRQ/YbHPKaw7aRphZO63dKvIKp4cQVtkWQHi6s +yTjGsgkLcLNau5LZDQUdsGV+SAo3nBdWCRYV+I65x8Kf4hCxqqmjV3d/2NKRu0BXnDe/N+iDz3X0 +zEoj0fqXgq4SWcC0nsG1lyyXt1TL270I6ATKRGJWiQVCCpDtc0NT6vdJ45bCSxgCAwEAAaOBlzCB +lDAMBgNVHRMBAf8EAjAAMB4GA1UdEQQXMBWBE2FsaWNlQHNtaW1lLmV4YW1wbGUwEwYDVR0lBAww +CgYIKwYBBQUHAwQwDwYDVR0PAQH/BAUDAwfAADAdBgNVHQ4EFgQUu/bMsi0dBhIcl64papAQ0yBm +ZnMwHwYDVR0jBBgwFoAUeF8OWnjYa+RUcD2z3ez38fL6wEcwDQYJKoZIhvcNAQENBQADggEBABbW +eonR6TMTckehDKNOabwaCIcekahAIL6l9tTzUX5ew6ufiAPlC6I/zQlmUaU0iSyFDG1NW14kNbFt +5CAokyLhMtE4ASHBIHbiOp/ZSbUBTVYJZB61ot7w1/ol5QECSs08b8zrxIncf+t2DHGuVEy/Qq1d +rBz8d4ay8zpqAE1tUyL5Da6ZiKUfWwZQXSI/JlbjQFzYQqTRDnzHWrg1xPeMTO1P2/cplFaseTiv +yk4cYwOp/W9UAWymOZXF8WcJYCIUXkdcG/nEZxr057KlScrJmFXOoh7Y+8ON4iWYYcAfiNgpUFo/ +j8BAwrKKaFvdlZS9k1Ypb2+UQY75mKJE9Bg= +-----END CERTIFICATE----- ================================================ FILE: third_party/patch/openssl/CVE-2023-0286.patch ================================================ diff --git a/crypto/x509v3/v3_genn.c b/crypto/x509v3/v3_genn.c index 87a5eff47c..e54ddc55c9 100644 --- a/crypto/x509v3/v3_genn.c +++ b/crypto/x509v3/v3_genn.c @@ -98,7 +98,7 @@ int GENERAL_NAME_cmp(GENERAL_NAME *a, GENERAL_NAME *b) return -1; switch (a->type) { case GEN_X400: - result = ASN1_TYPE_cmp(a->d.x400Address, b->d.x400Address); + result = ASN1_STRING_cmp(a->d.x400Address, b->d.x400Address); break; case GEN_EDIPARTY: diff --git a/include/openssl/x509v3.h b/include/openssl/x509v3.h index 90fa3592ce..e61c0f29d4 100644 --- a/include/openssl/x509v3.h +++ b/include/openssl/x509v3.h @@ -136,7 +136,7 @@ typedef struct GENERAL_NAME_st { OTHERNAME *otherName; /* otherName */ ASN1_IA5STRING *rfc822Name; ASN1_IA5STRING *dNSName; - ASN1_TYPE *x400Address; + ASN1_STRING *x400Address; X509_NAME *directoryName; EDIPARTYNAME *ediPartyName; ASN1_IA5STRING *uniformResourceIdentifier; diff --git a/test/v3nametest.c b/test/v3nametest.c index d1852190b8..37819da8fd 100644 --- a/test/v3nametest.c +++ b/test/v3nametest.c @@ -646,6 +646,14 @@ static struct gennamedata { 0xb7, 0x09, 0x02, 0x02 }, 15 + }, { + /* + * Regression test for CVE-2023-0286. + */ + { + 0xa3, 0x00 + }, + 2 } }; ================================================ FILE: third_party/patch/openssl/CVE-2023-0464.patch ================================================ From 879f7080d7e141f415c79eaa3a8ac4a3dad0348b Mon Sep 17 00:00:00 2001 From: Pauli Date: Wed, 8 Mar 2023 15:28:20 +1100 Subject: [PATCH] x509: excessive resource use verifying policy constraints A security vulnerability has been identified in all supported versions of OpenSSL related to the verification of X.509 certificate chains that include policy constraints. Attackers may be able to exploit this vulnerability by creating a malicious certificate chain that triggers exponential use of computational resources, leading to a denial-of-service (DoS) attack on affected systems. 
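The guard this patch introduces below is easiest to see in isolation. What follows is a minimal standalone sketch of the mitigation pattern, not upstream code (the struct and all names are hypothetical): every node allocation funnels through one check of a per-tree counter against a fixed ceiling, so a hostile certificate chain gets a cheap refusal instead of exponential tree growth.

#include <stdio.h>
#include <stdlib.h>

#define TREE_NODES_MAX 1000            /* mirrors OPENSSL_POLICY_TREE_NODES_MAX */

struct tree {
    size_t node_count;                 /* nodes allocated so far */
    size_t node_maximum;               /* hard ceiling; 0 means unlimited */
};

static void *tree_add_node(struct tree *t)
{
    /* refuse to grow past the ceiling; callers treat NULL as failure */
    if (t->node_maximum > 0 && t->node_count >= t->node_maximum)
        return NULL;
    t->node_count++;
    return calloc(1, 16);              /* stand-in for the real node payload */
}

int main(void)
{
    struct tree t = { 0, TREE_NODES_MAX };
    void *n;

    while ((n = tree_add_node(&t)) != NULL)
        free(n);
    printf("refused after %zu nodes\n", t.node_count);   /* prints 1000 */
    return 0;
}

The point of the design is that the bound is enforced at the single allocation choke point rather than at every call site, which is exactly how the patch threads the tree pointer through level_add_node() below.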
Fixes CVE-2023-0464 Reviewed-by: Tomas Mraz Reviewed-by: Shane Lontis (Merged from https://github.com/openssl/openssl/pull/20569) --- crypto/x509v3/pcy_local.h | 8 +++++++- crypto/x509v3/pcy_node.c | 12 +++++++++--- crypto/x509v3/pcy_tree.c | 37 +++++++++++++++++++++++++++---------- 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/crypto/x509v3/pcy_local.h b/crypto/x509v3/pcy_local.h index 5daf78de45..344aa06765 100644 --- a/crypto/x509v3/pcy_local.h +++ b/crypto/x509v3/pcy_local.h @@ -111,6 +111,11 @@ struct X509_POLICY_LEVEL_st { }; struct X509_POLICY_TREE_st { + /* The number of nodes in the tree */ + size_t node_count; + /* The maximum number of nodes in the tree */ + size_t node_maximum; + /* This is the tree 'level' data */ X509_POLICY_LEVEL *levels; int nlevel; @@ -159,7 +164,8 @@ X509_POLICY_NODE *tree_find_sk(STACK_OF(X509_POLICY_NODE) *sk, X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, X509_POLICY_DATA *data, X509_POLICY_NODE *parent, - X509_POLICY_TREE *tree); + X509_POLICY_TREE *tree, + int extra_data); void policy_node_free(X509_POLICY_NODE *node); int policy_node_match(const X509_POLICY_LEVEL *lvl, const X509_POLICY_NODE *node, const ASN1_OBJECT *oid); diff --git a/crypto/x509v3/pcy_node.c b/crypto/x509v3/pcy_node.c index e2d7b15322..d574fb9d66 100644 --- a/crypto/x509v3/pcy_node.c +++ b/crypto/x509v3/pcy_node.c @@ -59,10 +59,15 @@ X509_POLICY_NODE *level_find_node(const X509_POLICY_LEVEL *level, X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, X509_POLICY_DATA *data, X509_POLICY_NODE *parent, - X509_POLICY_TREE *tree) + X509_POLICY_TREE *tree, + int extra_data) { X509_POLICY_NODE *node; + /* Verify that the tree isn't too large. This mitigates CVE-2023-0464 */ + if (tree->node_maximum > 0 && tree->node_count >= tree->node_maximum) + return NULL; + node = OPENSSL_zalloc(sizeof(*node)); if (node == NULL) { X509V3err(X509V3_F_LEVEL_ADD_NODE, ERR_R_MALLOC_FAILURE); @@ -70,7 +75,7 @@ X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, } node->data = data; node->parent = parent; - if (level) { + if (level != NULL) { if (OBJ_obj2nid(data->valid_policy) == NID_any_policy) { if (level->anyPolicy) goto node_error; @@ -90,7 +95,7 @@ X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, } } - if (tree) { + if (extra_data) { if (tree->extra_data == NULL) tree->extra_data = sk_X509_POLICY_DATA_new_null(); if (tree->extra_data == NULL){ @@ -103,6 +108,7 @@ X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, } } + tree->node_count++; if (parent) parent->nchild++; diff --git a/crypto/x509v3/pcy_tree.c b/crypto/x509v3/pcy_tree.c index 6e8322cbc5..6c7fd35405 100644 --- a/crypto/x509v3/pcy_tree.c +++ b/crypto/x509v3/pcy_tree.c @@ -13,6 +13,18 @@ #include "pcy_local.h" +/* + * If the maximum number of nodes in the policy tree isn't defined, set it to + * a generous default of 1000 nodes. + * + * Defining this to be zero means unlimited policy tree growth which opens the + * door on CVE-2023-0464. + */ + +#ifndef OPENSSL_POLICY_TREE_NODES_MAX +# define OPENSSL_POLICY_TREE_NODES_MAX 1000 +#endif + /* * Enable this to print out the complete policy tree at various point during * evaluation. @@ -168,6 +180,9 @@ static int tree_init(X509_POLICY_TREE **ptree, STACK_OF(X509) *certs, return X509_PCY_TREE_INTERNAL; } + /* Limit the growth of the tree to mitigate CVE-2023-0464 */ + tree->node_maximum = OPENSSL_POLICY_TREE_NODES_MAX; + /* * http://tools.ietf.org/html/rfc5280#section-6.1.2, figure 3. 
* @@ -184,7 +199,7 @@ static int tree_init(X509_POLICY_TREE **ptree, STACK_OF(X509) *certs, level = tree->levels; if ((data = policy_data_new(NULL, OBJ_nid2obj(NID_any_policy), 0)) == NULL) goto bad_tree; - if (level_add_node(level, data, NULL, tree) == NULL) { + if (level_add_node(level, data, NULL, tree, 1) == NULL) { policy_data_free(data); goto bad_tree; } @@ -243,7 +258,8 @@ static int tree_init(X509_POLICY_TREE **ptree, STACK_OF(X509) *certs, * Return value: 1 on success, 0 otherwise */ static int tree_link_matching_nodes(X509_POLICY_LEVEL *curr, - X509_POLICY_DATA *data) + X509_POLICY_DATA *data, + X509_POLICY_TREE *tree) { X509_POLICY_LEVEL *last = curr - 1; int i, matched = 0; @@ -253,13 +269,13 @@ static int tree_link_matching_nodes(X509_POLICY_LEVEL *curr, X509_POLICY_NODE *node = sk_X509_POLICY_NODE_value(last->nodes, i); if (policy_node_match(last, node, data->valid_policy)) { - if (level_add_node(curr, data, node, NULL) == NULL) + if (level_add_node(curr, data, node, tree, 0) == NULL) return 0; matched = 1; } } if (!matched && last->anyPolicy) { - if (level_add_node(curr, data, last->anyPolicy, NULL) == NULL) + if (level_add_node(curr, data, last->anyPolicy, tree, 0) == NULL) return 0; } return 1; @@ -272,7 +288,8 @@ static int tree_link_matching_nodes(X509_POLICY_LEVEL *curr, * Return value: 1 on success, 0 otherwise. */ static int tree_link_nodes(X509_POLICY_LEVEL *curr, - const X509_POLICY_CACHE *cache) + const X509_POLICY_CACHE *cache, + X509_POLICY_TREE *tree) { int i; @@ -280,7 +297,7 @@ static int tree_link_nodes(X509_POLICY_LEVEL *curr, X509_POLICY_DATA *data = sk_X509_POLICY_DATA_value(cache->data, i); /* Look for matching nodes in previous level */ - if (!tree_link_matching_nodes(curr, data)) + if (!tree_link_matching_nodes(curr, data, tree)) return 0; } return 1; @@ -311,7 +328,7 @@ static int tree_add_unmatched(X509_POLICY_LEVEL *curr, /* Curr may not have anyPolicy */ data->qualifier_set = cache->anyPolicy->qualifier_set; data->flags |= POLICY_DATA_FLAG_SHARED_QUALIFIERS; - if (level_add_node(curr, data, node, tree) == NULL) { + if (level_add_node(curr, data, node, tree, 1) == NULL) { policy_data_free(data); return 0; } @@ -373,7 +390,7 @@ static int tree_link_any(X509_POLICY_LEVEL *curr, } /* Finally add link to anyPolicy */ if (last->anyPolicy && - level_add_node(curr, cache->anyPolicy, last->anyPolicy, NULL) == NULL) + level_add_node(curr, cache->anyPolicy, last->anyPolicy, tree, 0) == NULL) return 0; return 1; } @@ -555,7 +572,7 @@ static int tree_calculate_user_set(X509_POLICY_TREE *tree, extra->qualifier_set = anyPolicy->data->qualifier_set; extra->flags = POLICY_DATA_FLAG_SHARED_QUALIFIERS | POLICY_DATA_FLAG_EXTRA_NODE; - node = level_add_node(NULL, extra, anyPolicy->parent, tree); + node = level_add_node(NULL, extra, anyPolicy->parent, tree, 1); } if (!tree->user_policies) { tree->user_policies = sk_X509_POLICY_NODE_new_null(); @@ -582,7 +599,7 @@ static int tree_evaluate(X509_POLICY_TREE *tree) for (i = 1; i < tree->nlevel; i++, curr++) { cache = policy_cache_set(curr->cert); - if (!tree_link_nodes(curr, cache)) + if (!tree_link_nodes(curr, cache, tree)) return X509_PCY_TREE_INTERNAL; if (!(curr->flags & X509_V_FLAG_INHIBIT_ANY) -- 2.34.1 ================================================ FILE: third_party/patch/openssl/CVE-2023-0465.patch ================================================ From b013765abfa80036dc779dd0e50602c57bb3bf95 Mon Sep 17 00:00:00 2001 From: Matt Caswell Date: Tue, 7 Mar 2023 16:52:55 +0000 Subject: [PATCH] Ensure that 
EXFLAG_INVALID_POLICY is checked even in leaf certs Even though we check the leaf cert to confirm it is valid, we later ignored the invalid flag and did not notice that the leaf cert was bad. Fixes: CVE-2023-0465 Reviewed-by: Hugo Landau Reviewed-by: Tomas Mraz (Merged from https://github.com/openssl/openssl/pull/20588) --- crypto/x509/x509_vfy.c | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/crypto/x509/x509_vfy.c b/crypto/x509/x509_vfy.c index 925fbb5412..1dfe4f9f31 100644 --- a/crypto/x509/x509_vfy.c +++ b/crypto/x509/x509_vfy.c @@ -1649,18 +1649,25 @@ static int check_policy(X509_STORE_CTX *ctx) } /* Invalid or inconsistent extensions */ if (ret == X509_PCY_TREE_INVALID) { - int i; + int i, cbcalled = 0; /* Locate certificates with bad extensions and notify callback. */ - for (i = 1; i < sk_X509_num(ctx->chain); i++) { + for (i = 0; i < sk_X509_num(ctx->chain); i++) { X509 *x = sk_X509_value(ctx->chain, i); if (!(x->ex_flags & EXFLAG_INVALID_POLICY)) continue; + cbcalled = 1; if (!verify_cb_cert(ctx, x, i, X509_V_ERR_INVALID_POLICY_EXTENSION)) return 0; } + if (!cbcalled) { + /* Should not be able to get here */ + X509err(X509_F_CHECK_POLICY, ERR_R_INTERNAL_ERROR); + return 0; + } + /* The callback ignored the error so we return success */ return 1; } if (ret == X509_PCY_TREE_FAILURE) { -- 2.34.1 ================================================ FILE: third_party/patch/openssl/CVE-2023-0466.patch ================================================ diff --git a/doc/man3/X509_VERIFY_PARAM_set_flags.pod b/doc/man3/X509_VERIFY_PARAM_set_flags.pod index f6f304bf7b..aa292f9336 100644 --- a/doc/man3/X509_VERIFY_PARAM_set_flags.pod +++ b/doc/man3/X509_VERIFY_PARAM_set_flags.pod @@ -92,8 +92,9 @@ B<trust>. X509_VERIFY_PARAM_set_time() sets the verification time in B<param> to B<t>. Normally the current time is used. -X509_VERIFY_PARAM_add0_policy() enables policy checking (it is disabled -by default) and adds B<policy> to the acceptable policy set. +X509_VERIFY_PARAM_add0_policy() adds B<policy> to the acceptable policy set. +Contrary to preexisting documentation of this function it does not enable +policy checking. X509_VERIFY_PARAM_set1_policies() enables policy checking (it is disabled by default) and sets the acceptable policy set to B<policies>. Any existing @@ -377,6 +378,10 @@ and has no effect. The X509_VERIFY_PARAM_get_hostflags() function was added in OpenSSL 1.1.0i. +The function X509_VERIFY_PARAM_add0_policy() was historically documented as +enabling policy checking however the implementation has never done this. +The documentation was changed to align with the implementation. + =head1 COPYRIGHT Copyright 2009-2020 The OpenSSL Project Authors. All Rights Reserved. ================================================ FILE: third_party/patch/openssl/CVE-2023-2650.patch ================================================ From 9e209944b35cf82368071f160a744b6178f9b098 Mon Sep 17 00:00:00 2001 From: Richard Levitte Date: Fri, 12 May 2023 10:00:13 +0200 Subject: [PATCH] Restrict the size of OBJECT IDENTIFIERs that OBJ_obj2txt will translate OBJ_obj2txt() would translate any size OBJECT IDENTIFIER to canonical numeric text form. For gigantic sub-identifiers, this would take a very long time, the time complexity being O(n^2) where n is the size of that sub-identifier. To mitigate this, a restriction on the size that OBJ_obj2txt() will translate to canonical numeric text form is added, based on RFC 2578 (STD 58), which says this: > 3.5.
OBJECT IDENTIFIER values > > An OBJECT IDENTIFIER value is an ordered list of non-negative numbers. > For the SMIv2, each number in the list is referred to as a sub-identifier, > there are at most 128 sub-identifiers in a value, and each sub-identifier > has a maximum value of 2^32-1 (4294967295 decimal). Fixes otc/security#96 Fixes CVE-2023-2650 Reviewed-by: Matt Caswell Reviewed-by: Tomas Mraz --- crypto/objects/obj_dat.c | 19 +++++++++++++++++++ diff --git a/crypto/objects/obj_dat.c b/crypto/objects/obj_dat.c index 7e8de727f3..d699915b20 100644 --- a/crypto/objects/obj_dat.c +++ b/crypto/objects/obj_dat.c @@ -428,6 +428,25 @@ int OBJ_obj2txt(char *buf, int buf_len, const ASN1_OBJECT *a, int no_name) first = 1; bl = NULL; + /* + * RFC 2578 (STD 58) says this about OBJECT IDENTIFIERs: + * + * > 3.5. OBJECT IDENTIFIER values + * > + * > An OBJECT IDENTIFIER value is an ordered list of non-negative + * > numbers. For the SMIv2, each number in the list is referred to as a + * > sub-identifier, there are at most 128 sub-identifiers in a value, + * > and each sub-identifier has a maximum value of 2^32-1 (4294967295 + * > decimal). + * + * So a legitimate OID according to this RFC is at most (32 * 128 / 7), + * i.e. 586 bytes long. + * + * Ref: https://datatracker.ietf.org/doc/html/rfc2578#section-3.5 + */ + if (len > 586) + goto err; + while (len > 0) { l = 0; use_bn = 0; -- 2.34.1 ================================================ FILE: third_party/patch/openssl/CVE-2023-3446.patch ================================================ From 8780a896543a654e757db1b9396383f9d8095528 Mon Sep 17 00:00:00 2001 From: Matt Caswell Date: Thu, 6 Jul 2023 16:36:35 +0100 Subject: [PATCH] Fix DH_check() excessive time with over sized modulus The DH_check() function checks numerous aspects of the key or parameters that have been supplied. Some of those checks use the supplied modulus value even if it is excessively large. There is already a maximum DH modulus size (10,000 bits) over which OpenSSL will not generate or derive keys. DH_check() will however still perform various tests for validity on such a large modulus. We introduce a new maximum (32,768) over which DH_check() will just fail. An application that calls DH_check() and supplies a key or parameters obtained from an untrusted source could be vulnerable to a Denial of Service attack. The function DH_check() is itself called by a number of other OpenSSL functions. An application calling any of those other functions may similarly be affected. The other functions affected by this are DH_check_ex() and EVP_PKEY_param_check(). 
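The shape of the fix is a fail-fast bound check ahead of the expensive validation. Here is a minimal sketch of that pattern (illustrative only, not upstream code; the function name is hypothetical and the constant mirrors the OPENSSL_DH_CHECK_MAX_MODULUS_BITS value the patch adds below):

#include <stdio.h>

#define CHECK_MAX_MODULUS_BITS 32768   /* mirrors OPENSSL_DH_CHECK_MAX_MODULUS_BITS */

static int dh_check_sketch(unsigned int modulus_bits)
{
    /* bail out in O(1) before any costly primality or generator tests */
    if (modulus_bits > CHECK_MAX_MODULUS_BITS)
        return 0;
    /* ... the expensive checks on p, q and g would run here ... */
    return 1;
}

int main(void)
{
    printf("%d\n", dh_check_sketch(2048));     /* 1: small enough, proceed */
    printf("%d\n", dh_check_sketch(1000000));  /* 0: rejected immediately */
    return 0;
}

Placing the rejection before, rather than inside, the per-check loops is what turns an attacker-controlled cost into a constant-time refusal.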
CVE-2023-3446 Reviewed-by: Paul Dale Reviewed-by: Tom Cosgrove Reviewed-by: Bernd Edlinger Reviewed-by: Tomas Mraz (Merged from https://github.com/openssl/openssl/pull/21452) --- crypto/dh/dh_check.c | 6 ++++++ crypto/dh/dh_err.c | 3 ++- crypto/err/openssl.txt | 1 + include/openssl/dh.h | 3 +++ include/openssl/dherr.h | 3 ++- 5 files changed, 15 insertions(+), 3 deletions(-) diff --git a/crypto/dh/dh_check.c b/crypto/dh/dh_check.c index 4ac169e75c..e5f9dd5030 100644 --- a/crypto/dh/dh_check.c +++ b/crypto/dh/dh_check.c @@ -101,6 +101,12 @@ int DH_check(const DH *dh, int *ret) BN_CTX *ctx = NULL; BIGNUM *t1 = NULL, *t2 = NULL; + /* Don't do any checks at all with an excessively large modulus */ + if (BN_num_bits(dh->p) > OPENSSL_DH_CHECK_MAX_MODULUS_BITS) { + DHerr(DH_F_DH_CHECK, DH_R_MODULUS_TOO_LARGE); + return 0; + } + if (!DH_check_params(dh, ret)) return 0; diff --git a/crypto/dh/dh_err.c b/crypto/dh/dh_err.c index 7285587b4a..92800d3fcc 100644 --- a/crypto/dh/dh_err.c +++ b/crypto/dh/dh_err.c @@ -1,6 +1,6 @@ /* * Generated by util/mkerr.pl DO NOT EDIT - * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved. + * Copyright 1995-2023 The OpenSSL Project Authors. All Rights Reserved. * * Licensed under the OpenSSL license (the "License"). You may not use * this file except in compliance with the License. You can obtain a copy @@ -18,6 +18,7 @@ static const ERR_STRING_DATA DH_str_functs[] = { {ERR_PACK(ERR_LIB_DH, DH_F_DHPARAMS_PRINT_FP, 0), "DHparams_print_fp"}, {ERR_PACK(ERR_LIB_DH, DH_F_DH_BUILTIN_GENPARAMS, 0), "dh_builtin_genparams"}, + {ERR_PACK(ERR_LIB_DH, DH_F_DH_CHECK, 0), "DH_check"}, {ERR_PACK(ERR_LIB_DH, DH_F_DH_CHECK_EX, 0), "DH_check_ex"}, {ERR_PACK(ERR_LIB_DH, DH_F_DH_CHECK_PARAMS_EX, 0), "DH_check_params_ex"}, {ERR_PACK(ERR_LIB_DH, DH_F_DH_CHECK_PUB_KEY_EX, 0), "DH_check_pub_key_ex"}, diff --git a/crypto/err/openssl.txt b/crypto/err/openssl.txt index 9f91a4a811..c0a3cd720b 100644 --- a/crypto/err/openssl.txt +++ b/crypto/err/openssl.txt @@ -401,6 +401,7 @@ CT_F_SCT_SET_VERSION:104:SCT_set_version DH_F_COMPUTE_KEY:102:compute_key DH_F_DHPARAMS_PRINT_FP:101:DHparams_print_fp DH_F_DH_BUILTIN_GENPARAMS:106:dh_builtin_genparams +DH_F_DH_CHECK:126:DH_check DH_F_DH_CHECK_EX:121:DH_check_ex DH_F_DH_CHECK_PARAMS_EX:122:DH_check_params_ex DH_F_DH_CHECK_PUB_KEY_EX:123:DH_check_pub_key_ex diff --git a/include/openssl/dh.h b/include/openssl/dh.h index 3527540cdd..892e31559d 100644 --- a/include/openssl/dh.h +++ b/include/openssl/dh.h @@ -29,6 +29,9 @@ extern "C" { # ifndef OPENSSL_DH_MAX_MODULUS_BITS # define OPENSSL_DH_MAX_MODULUS_BITS 10000 # endif +# ifndef OPENSSL_DH_CHECK_MAX_MODULUS_BITS +# define OPENSSL_DH_CHECK_MAX_MODULUS_BITS 32768 +# endif # define OPENSSL_DH_FIPS_MIN_MODULUS_BITS 1024 diff --git a/include/openssl/dherr.h b/include/openssl/dherr.h index 916b3bed0b..528c819856 100644 --- a/include/openssl/dherr.h +++ b/include/openssl/dherr.h @@ -1,6 +1,6 @@ /* * Generated by util/mkerr.pl DO NOT EDIT - * Copyright 1995-2019 The OpenSSL Project Authors. All Rights Reserved. + * Copyright 1995-2023 The OpenSSL Project Authors. All Rights Reserved. * * Licensed under the OpenSSL license (the "License"). You may not use * this file except in compliance with the License. 
You can obtain a copy @@ -30,6 +30,7 @@ int ERR_load_DH_strings(void); # define DH_F_COMPUTE_KEY 102 # define DH_F_DHPARAMS_PRINT_FP 101 # define DH_F_DH_BUILTIN_GENPARAMS 106 +# define DH_F_DH_CHECK 126 # define DH_F_DH_CHECK_EX 121 # define DH_F_DH_CHECK_PARAMS_EX 122 # define DH_F_DH_CHECK_PUB_KEY_EX 123 -- 2.34.1 ================================================ FILE: third_party/patch/openssl/CVE-2023-4807.patch ================================================ From a632d534c73eeb3e3db8c7540d811194ef7c79ff Mon Sep 17 00:00:00 2001 From: Bernd Edlinger Date: Tue, 22 Aug 2023 16:07:30 +0200 Subject: [PATCH] Avoid clobbering non-volatile XMM registers This affects some Poly1305 assembler functions which are only used for certain CPU types. Remove those functions for Windows targets, as a simple interim solution. Fixes #21522 Reviewed-by: Tomas Mraz Reviewed-by: Paul Dale (Merged from https://github.com/openssl/openssl/pull/21808) (cherry picked from commit 7b8e27bc2e02238986d89ef0ece067ec1b48e165) --- crypto/poly1305/asm/poly1305-x86_64.pl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/poly1305/asm/poly1305-x86_64.pl b/crypto/poly1305/asm/poly1305-x86_64.pl index 5f834d8faf..801455c639 100755 --- a/crypto/poly1305/asm/poly1305-x86_64.pl +++ b/crypto/poly1305/asm/poly1305-x86_64.pl @@ -193,7 +193,7 @@ $code.=<<___ if ($avx>1); bt \$`5+32`,%r9 # AVX2? cmovc %rax,%r10 ___ -$code.=<<___ if ($avx>3); +$code.=<<___ if ($avx>3 && !$win64); mov \$`(1<<31|1<<21|1<<16)`,%rax shr \$32,%r9 and %rax,%r9 @@ -2722,7 +2722,7 @@ $code.=<<___; .cfi_endproc .size poly1305_blocks_avx512,.-poly1305_blocks_avx512 ___ -if ($avx>3) { +if ($avx>3 && !$win64) { ######################################################################## # VPMADD52 version using 2^44 radix. 
# -- 2.34.1 ================================================ FILE: third_party/patch/protobuf/CVE-2021-22570.patch ================================================ diff --git a/src/google/protobuf/descriptor.cc b/src/google/protobuf/descriptor.cc index 9a448ffc8..40510b46c 100644 --- a/src/google/protobuf/descriptor.cc +++ b/src/google/protobuf/descriptor.cc @@ -1090,7 +1090,7 @@ inline void DescriptorPool::Tables::FindAllExtensions( bool DescriptorPool::Tables::AddSymbol(const std::string& full_name, Symbol symbol) { - if (InsertIfNotPresent(&symbols_by_name_, full_name.c_str(), symbol)) { + if (InsertIfNotPresent(&symbols_by_name_, full_name, symbol)) { symbols_after_checkpoint_.push_back(full_name.c_str()); return true; } else { @@ -1106,7 +1106,7 @@ bool FileDescriptorTables::AddAliasUnderParent(const void* parent, } bool DescriptorPool::Tables::AddFile(const FileDescriptor* file) { - if (InsertIfNotPresent(&files_by_name_, file->name().c_str(), file)) { + if (InsertIfNotPresent(&files_by_name_, file->name(), file)) { files_after_checkpoint_.push_back(file->name().c_str()); return true; } else { @@ -2628,6 +2628,8 @@ void Descriptor::DebugString(int depth, std::string* contents, const Descriptor::ReservedRange* range = reserved_range(i); if (range->end == range->start + 1) { strings::SubstituteAndAppend(contents, "$0, ", range->start); + } else if (range->end > FieldDescriptor::kMaxNumber) { + strings::SubstituteAndAppend(contents, "$0 to max, ", range->start); } else { strings::SubstituteAndAppend(contents, "$0 to $1, ", range->start, range->end - 1); @@ -2831,6 +2833,8 @@ void EnumDescriptor::DebugString( const EnumDescriptor::ReservedRange* range = reserved_range(i); if (range->end == range->start) { strings::SubstituteAndAppend(contents, "$0, ", range->start); + } else if (range->end == INT_MAX) { + strings::SubstituteAndAppend(contents, "$0 to max, ", range->start); } else { strings::SubstituteAndAppend(contents, "$0 to $1, ", range->start, range->end); @@ -4022,6 +4026,12 @@ bool DescriptorBuilder::AddSymbol(const std::string& full_name, // Use its file as the parent instead. if (parent == nullptr) parent = file_; + if (full_name.find('\0') != std::string::npos) { + AddError(full_name, proto, DescriptorPool::ErrorCollector::NAME, + "\"" + full_name + "\" contains null character."); + return false; + } + if (tables_->AddSymbol(full_name, symbol)) { if (!file_tables_->AddAliasUnderParent(parent, name, symbol)) { // This is only possible if there was already an error adding something of @@ -4061,6 +4071,11 @@ bool DescriptorBuilder::AddSymbol(const std::string& full_name, void DescriptorBuilder::AddPackage(const std::string& name, const Message& proto, const FileDescriptor* file) { + if (name.find('\0') != std::string::npos) { + AddError(name, proto, DescriptorPool::ErrorCollector::NAME, + "\"" + name + "\" contains null character."); + return; + } if (tables_->AddSymbol(name, Symbol(file))) { // Success. Also add parent package, if any. std::string::size_type dot_pos = name.find_last_of('.'); @@ -4374,6 +4389,12 @@ FileDescriptor* DescriptorBuilder::BuildFileImpl( } result->pool_ = pool_; + if (result->name().find('\0') != std::string::npos) { + AddError(result->name(), proto, DescriptorPool::ErrorCollector::NAME, + "\"" + result->name() + "\" contains null character."); + return nullptr; + } + // Add to tables. 
if (!tables_->AddFile(result)) { AddError(proto.name(), proto, DescriptorPool::ErrorCollector::OTHER, diff --git a/src/google/protobuf/descriptor_unittest.cc b/src/google/protobuf/descriptor_unittest.cc index 6085a122a..56c180aa4 100644 --- a/src/google/protobuf/descriptor_unittest.cc +++ b/src/google/protobuf/descriptor_unittest.cc @@ -3786,6 +3786,45 @@ TEST_F(ValidationErrorTest, InvalidPackageName) { "foo.proto: foo.$: NAME: \"$\" is not a valid identifier.\n"); } +// 'str' is a static C-style string that may contain '\0' +#define STATIC_STR(str) std::string((str), sizeof(str) - 1) + +TEST_F(ValidationErrorTest, NullCharSymbolName) { + BuildFileWithErrors( + "name: \"bar.proto\" " + "package: \"foo\"" + "message_type { " + " name: '\\000\\001\\013.Bar' " + " field { name: \"foo\" number: 9 label:LABEL_OPTIONAL type:TYPE_INT32 " + "} " + "}", + STATIC_STR("bar.proto: foo.\0\x1\v.Bar: NAME: \"\0\x1\v.Bar\" is not a " + "valid identifier.\nbar.proto: foo.\0\x1\v.Bar: NAME: " + "\"\0\x1\v.Bar\" is not a valid identifier.\nbar.proto: " + "foo.\0\x1\v.Bar: NAME: \"\0\x1\v.Bar\" is not a valid " + "identifier.\nbar.proto: foo.\0\x1\v.Bar: NAME: " + "\"\0\x1\v.Bar\" is not a valid identifier.\nbar.proto: " + "foo.\0\x1\v.Bar.foo: NAME: \"foo.\0\x1\v.Bar.foo\" contains " + "null character.\nbar.proto: foo.\0\x1\v.Bar: NAME: " + "\"foo.\0\x1\v.Bar\" contains null character.\n")); +} + +TEST_F(ValidationErrorTest, NullCharFileName) { + BuildFileWithErrors( + "name: \"bar\\000\\001\\013.proto\" " + "package: \"outer.foo\"", + STATIC_STR("bar\0\x1\v.proto: bar\0\x1\v.proto: NAME: " + "\"bar\0\x1\v.proto\" contains null character.\n")); +} + +TEST_F(ValidationErrorTest, NullCharPackageName) { + BuildFileWithErrors( + "name: \"bar.proto\" " + "package: \"\\000\\001\\013.\"", + STATIC_STR("bar.proto: \0\x1\v.: NAME: \"\0\x1\v.\" contains null " + "character.\n")); +} + TEST_F(ValidationErrorTest, MissingFileName) { BuildFileWithErrors("", @@ -4001,6 +4040,32 @@ TEST_F(ValidationErrorTest, ReservedFieldsDebugString) { file->DebugString()); } +TEST_F(ValidationErrorTest, DebugStringReservedRangeMax) { + const FileDescriptor* file = BuildFile(strings::Substitute( + "name: \"foo.proto\" " + "enum_type { " + " name: \"Bar\"" + " value { name:\"BAR\" number:1 }" + " reserved_range { start: 5 end: $0 }" + "}" + "message_type {" + " name: \"Foo\"" + " reserved_range { start: 5 end: $1 }" + "}", + std::numeric_limits::max(), FieldDescriptor::kMaxNumber + 1)); + + ASSERT_EQ( + "syntax = \"proto2\";\n\n" + "enum Bar {\n" + " BAR = 1;\n" + " reserved 5 to max;\n" + "}\n\n" + "message Foo {\n" + " reserved 5 to max;\n" + "}\n\n", + file->DebugString()); +} + TEST_F(ValidationErrorTest, EnumReservedFieldError) { BuildFileWithErrors( "name: \"foo.proto\" " ================================================ FILE: third_party/patch/protobuf/CVE-2022-1941.patch ================================================ diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h index 074784b96..aff050a81 100644 --- a/src/google/protobuf/extension_set_inl.h +++ b/src/google/protobuf/extension_set_inl.h @@ -206,16 +206,22 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( const char* ptr, const Msg* containing_type, internal::InternalMetadata* metadata, internal::ParseContext* ctx) { std::string payload; - uint32 type_id = 0; - bool payload_read = false; + + uint32_t type_id; + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + while (!ctx->Done(&ptr)) { 
uint32 tag = static_cast(*ptr++); if (tag == WireFormatLite::kMessageSetTypeIdTag) { uint64 tmp; ptr = ParseBigVarint(ptr, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - type_id = tmp; - if (payload_read) { + if (state == State::kNoTag) { + type_id = tmp; + state = State::kHasType; + } else if (state == State::kHasPayload) { + type_id = tmp; ExtensionInfo extension; bool was_packed_on_wire; if (!FindExtension(2, type_id, containing_type, ctx, &extension, @@ -241,20 +247,26 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && tmp_ctx.EndedAtLimit()); } - type_id = 0; + state = State::kDone; } } else if (tag == WireFormatLite::kMessageSetMessageTag) { - if (type_id != 0) { - ptr = ParseFieldMaybeLazily(static_cast(type_id) * 8 + 2, ptr, - containing_type, metadata, ctx); + + if (state == State::kHasType) { + ptr = ParseFieldMaybeLazily(static_cast(type_id) * 8 + 2, ptr, + containing_type, metadata, ctx); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); - type_id = 0; + state = State::kDone; } else { - int32 size = ReadSize(&ptr); + + std::string tmp; + int32_t size = ReadSize(&ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - ptr = ctx->ReadString(ptr, size, &payload); + ptr = ctx->ReadString(ptr, size, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - payload_read = true; + if (state == State::kNoTag) { + payload = std::move(tmp); + state = State::kHasPayload; + } } } else { ptr = ReadTag(ptr - 1, &tag); diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index 16edf2ce3..88fb09169 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -659,9 +659,11 @@ struct WireFormat::MessageSetParser { const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) { // Parse a MessageSetItem auto metadata = reflection->MutableInternalMetadata(msg); + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + std::string payload; - uint32 type_id = 0; - bool payload_read = false; + uint32_t type_id = 0; while (!ctx->Done(&ptr)) { // We use 64 bit tags in order to allow typeid's that span the whole // range of 32 bit numbers. 
@@ -670,8 +672,11 @@ struct WireFormat::MessageSetParser { uint64 tmp; ptr = ParseBigVarint(ptr, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - type_id = tmp; - if (payload_read) { + if (state == State::kNoTag) { + type_id = tmp; + state = State::kHasType; + } else if (state == State::kHasPayload) { + type_id = tmp; const FieldDescriptor* field; if (ctx->data().pool == nullptr) { field = reflection->FindKnownExtensionByNumber(type_id); @@ -698,17 +703,18 @@ struct WireFormat::MessageSetParser { GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && tmp_ctx.EndedAtLimit()); } - type_id = 0; + state = State::kDone; } continue; } else if (tag == WireFormatLite::kMessageSetMessageTag) { - if (type_id == 0) { - int32 size = ReadSize(&ptr); + + if (state == State::kNoTag) { + int32_t size = ReadSize(&ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); ptr = ctx->ReadString(ptr, size, &payload); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - payload_read = true; - } else { + state = State::kHasPayload; + } else if (state == State::kHasType) { // We're now parsing the payload const FieldDescriptor* field = nullptr; if (descriptor->IsExtensionNumber(type_id)) { @@ -722,7 +728,12 @@ struct WireFormat::MessageSetParser { ptr = WireFormat::_InternalParseAndMergeField( msg, ptr, ctx, static_cast(type_id) * 8 + 2, reflection, field); - type_id = 0; + state = State::kDone; + } else { + int32_t size = ReadSize(&ptr); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + ptr = ctx->Skip(ptr, size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); } } else { // An unknown field in MessageSetItem. diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h index c742fe869..4130bc531 100644 --- a/src/google/protobuf/wire_format_lite.h +++ b/src/google/protobuf/wire_format_lite.h @@ -1798,6 +1798,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { // we can parse it later. std::string message_data; + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + while (true) { const uint32 tag = input->ReadTagNoLastTag(); if (tag == 0) return false; @@ -1806,26 +1809,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { case WireFormatLite::kMessageSetTypeIdTag: { uint32 type_id; if (!input->ReadVarint32(&type_id)) return false; - last_type_id = type_id; - - if (!message_data.empty()) { + if (state == State::kNoTag) { + last_type_id = type_id; + state = State::kHasType; + } else if (state == State::kHasPayload) { // We saw some message data before the type_id. Have to parse it // now. io::CodedInputStream sub_input( reinterpret_cast(message_data.data()), static_cast(message_data.size())); sub_input.SetRecursionLimit(input->RecursionBudget()); - if (!ms.ParseField(last_type_id, &sub_input)) { + if (!ms.ParseField(type_id, &sub_input)) { return false; } message_data.clear(); + state = State::kDone; } break; } case WireFormatLite::kMessageSetMessageTag: { - if (last_type_id == 0) { + if (state == State::kHasType) { + // Already saw type_id, so we can parse this directly. + if (!ms.ParseField(last_type_id, input)) { + return false; + } + state = State::kDone; + } else if (state == State::kNoTag) { // We haven't seen a type_id yet. Append this data to message_data. 
uint32 length; if (!input->ReadVarint32(&length)) return false; @@ -1836,11 +1847,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { auto ptr = reinterpret_cast(&message_data[0]); ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr); if (!input->ReadRaw(ptr, length)) return false; + state = State::kHasPayload; } else { - // Already saw type_id, so we can parse this directly. - if (!ms.ParseField(last_type_id, input)) { - return false; - } + if (!ms.SkipField(tag, input)) return false; } break; ================================================ FILE: third_party/patch/pybind11/pybind11.patch001 ================================================ diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 3bffbb28..4a6a9809 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -21,6 +21,7 @@ # pragma warning disable 1875 // offsetof applied to non-POD (Plain Old Data) types is nonstandard # pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline" #elif defined(_MSC_VER) +#include # pragma warning(push) # pragma warning(disable: 4100) // warning C4100: Unreferenced formal parameter # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant ================================================ FILE: third_party/patch/zlib/CVE-2018-25032.patch ================================================ diff -Npur zlib-1.2.11/deflate.c zlib-1.2.11-change/deflate.c --- zlib-1.2.11/deflate.c 2017-01-16 01:29:40.000000000 +0800 +++ zlib-1.2.11-change/deflate.c 2022-07-28 04:48:30.310281281 +0800 @@ -252,10 +252,6 @@ int ZEXPORT deflateInit2_(strm, level, m int wrap = 1; static const char my_version[] = ZLIB_VERSION; - ushf *overlay; - /* We overlay pending_buf and d_buf+l_buf. This works since the average - * output size for (length,distance) codes is <= 24 bits. - */ if (version == Z_NULL || version[0] != my_version[0] || stream_size != sizeof(z_stream)) { @@ -326,9 +322,47 @@ int ZEXPORT deflateInit2_(strm, level, m s->lit_bufsize = 1 << (memLevel + 6); /* 16K elements by default */ - overlay = (ushf *) ZALLOC(strm, s->lit_bufsize, sizeof(ush)+2); - s->pending_buf = (uchf *) overlay; - s->pending_buf_size = (ulg)s->lit_bufsize * (sizeof(ush)+2L); + /* We overlay pending_buf and sym_buf. This works since the average size + * for length/distance pairs over any compressed block is assured to be 31 + * bits or less. + * + * Analysis: The longest fixed codes are a length code of 8 bits plus 5 + * extra bits, for lengths 131 to 257. The longest fixed distance codes are + * 5 bits plus 13 extra bits, for distances 16385 to 32768. The longest + * possible fixed-codes length/distance pair is then 31 bits total. + * + * sym_buf starts one-fourth of the way into pending_buf. So there are + * three bytes in sym_buf for every four bytes in pending_buf. Each symbol + * in sym_buf is three bytes -- two for the distance and one for the + * literal/length. As each symbol is consumed, the pointer to the next + * sym_buf value to read moves forward three bytes. From that symbol, up to + * 31 bits are written to pending_buf. The closest the written pending_buf + * bits gets to the next sym_buf symbol to read is just before the last + * code is written. At that time, 31*(n-2) bits have been written, just + * after 24*(n-2) bits have been consumed from sym_buf. sym_buf starts at + * 8*n bits into pending_buf. (Note that the symbol buffer fills when n-1 + * symbols are written.) 
The closest the writing gets to what is unread is + * then n+14 bits. Here n is lit_bufsize, which is 16384 by default, and + * can range from 128 to 32768. + * + * Therefore, at a minimum, there are 142 bits of space between what is + * written and what is read in the overlain buffers, so the symbols cannot + * be overwritten by the compressed data. That space is actually 139 bits, + * due to the three-bit fixed-code block header. + * + * That covers the case where either Z_FIXED is specified, forcing fixed + * codes, or when the use of fixed codes is chosen, because that choice + * results in a smaller compressed block than dynamic codes. That latter + * condition then assures that the above analysis also covers all dynamic + * blocks. A dynamic-code block will only be chosen to be emitted if it has + * fewer bits than a fixed-code block would for the same set of symbols. + * Therefore its average symbol length is assured to be less than 31. So + * the compressed data for a dynamic block also cannot overwrite the + * symbols from which it is being constructed. + */ + + s->pending_buf = (uchf *) ZALLOC(strm, s->lit_bufsize, 4); + s->pending_buf_size = (ulg)s->lit_bufsize * 4; if (s->window == Z_NULL || s->prev == Z_NULL || s->head == Z_NULL || s->pending_buf == Z_NULL) { @@ -337,8 +371,12 @@ int ZEXPORT deflateInit2_(strm, level, m deflateEnd (strm); return Z_MEM_ERROR; } - s->d_buf = overlay + s->lit_bufsize/sizeof(ush); - s->l_buf = s->pending_buf + (1+sizeof(ush))*s->lit_bufsize; + s->sym_buf = s->pending_buf + s->lit_bufsize; + s->sym_end = (s->lit_bufsize - 1) * 3; + /* We avoid equality with lit_bufsize*3 because of wraparound at 64K + * on 16 bit machines and because stored blocks are restricted to + * 64K-1 bytes. + */ s->level = level; s->strategy = strategy; @@ -549,7 +587,7 @@ int ZEXPORT deflatePrime (strm, bits, va if (deflateStateCheck(strm)) return Z_STREAM_ERROR; s = strm->state; - if ((Bytef *)(s->d_buf) < s->pending_out + ((Buf_size + 7) >> 3)) + if (s->sym_buf < s->pending_out + ((Buf_size + 7) >> 3)) return Z_BUF_ERROR; do { put = Buf_size - s->bi_valid; @@ -1108,7 +1146,6 @@ int ZEXPORT deflateCopy (dest, source) #else deflate_state *ds; deflate_state *ss; - ushf *overlay; if (deflateStateCheck(source) || dest == Z_NULL) { @@ -1128,8 +1165,7 @@ int ZEXPORT deflateCopy (dest, source) ds->window = (Bytef *) ZALLOC(dest, ds->w_size, 2*sizeof(Byte)); ds->prev = (Posf *) ZALLOC(dest, ds->w_size, sizeof(Pos)); ds->head = (Posf *) ZALLOC(dest, ds->hash_size, sizeof(Pos)); - overlay = (ushf *) ZALLOC(dest, ds->lit_bufsize, sizeof(ush)+2); - ds->pending_buf = (uchf *) overlay; + ds->pending_buf = (uchf *) ZALLOC(dest, ds->lit_bufsize, 4); if (ds->window == Z_NULL || ds->prev == Z_NULL || ds->head == Z_NULL || ds->pending_buf == Z_NULL) { @@ -1143,8 +1179,7 @@ int ZEXPORT deflateCopy (dest, source) zmemcpy(ds->pending_buf, ss->pending_buf, (uInt)ds->pending_buf_size); ds->pending_out = ds->pending_buf + (ss->pending_out - ss->pending_buf); - ds->d_buf = overlay + ds->lit_bufsize/sizeof(ush); - ds->l_buf = ds->pending_buf + (1+sizeof(ush))*ds->lit_bufsize; + ds->sym_buf = ds->pending_buf + ds->lit_bufsize; ds->l_desc.dyn_tree = ds->dyn_ltree; ds->d_desc.dyn_tree = ds->dyn_dtree; @@ -1912,7 +1947,7 @@ local block_state deflate_fast(s, flush) FLUSH_BLOCK(s, 1); return finish_done; } - if (s->last_lit) + if (s->sym_next) FLUSH_BLOCK(s, 0); return block_done; } @@ -2043,7 +2078,7 @@ local block_state deflate_slow(s, flush) FLUSH_BLOCK(s, 1); return finish_done; } - if (s->last_lit) 
+ if (s->sym_next) FLUSH_BLOCK(s, 0); return block_done; } @@ -2118,7 +2153,7 @@ local block_state deflate_rle(s, flush) FLUSH_BLOCK(s, 1); return finish_done; } - if (s->last_lit) + if (s->sym_next) FLUSH_BLOCK(s, 0); return block_done; } @@ -2157,7 +2192,7 @@ local block_state deflate_huff(s, flush) FLUSH_BLOCK(s, 1); return finish_done; } - if (s->last_lit) + if (s->sym_next) FLUSH_BLOCK(s, 0); return block_done; } diff -Npur zlib-1.2.11/deflate.h zlib-1.2.11-change/deflate.h --- zlib-1.2.11/deflate.h 2017-01-01 15:37:10.000000000 +0800 +++ zlib-1.2.11-change/deflate.h 2022-07-28 04:42:55.134287681 +0800 @@ -217,7 +217,7 @@ typedef struct internal_state { /* Depth of each subtree used as tie breaker for trees of equal frequency */ - uchf *l_buf; /* buffer for literals or lengths */ + uchf *sym_buf; /* buffer for distances and literals/lengths */ uInt lit_bufsize; /* Size of match buffer for literals/lengths. There are 4 reasons for @@ -239,13 +239,8 @@ typedef struct internal_state { * - I can't count above 4 */ - uInt last_lit; /* running index in l_buf */ - - ushf *d_buf; - /* Buffer for distances. To simplify the code, d_buf and l_buf have - * the same number of elements. To use different lengths, an extra flag - * array would be necessary. - */ + uInt sym_next; /* running index in sym_buf */ + uInt sym_end; /* symbol table full when sym_next reaches this */ ulg opt_len; /* bit length of current block with optimal trees */ ulg static_len; /* bit length of current block with static trees */ @@ -325,20 +320,22 @@ void ZLIB_INTERNAL _tr_stored_block OF(( # define _tr_tally_lit(s, c, flush) \ { uch cc = (c); \ - s->d_buf[s->last_lit] = 0; \ - s->l_buf[s->last_lit++] = cc; \ + s->sym_buf[s->sym_next++] = 0; \ + s->sym_buf[s->sym_next++] = 0; \ + s->sym_buf[s->sym_next++] = cc; \ s->dyn_ltree[cc].Freq++; \ - flush = (s->last_lit == s->lit_bufsize-1); \ + flush = (s->sym_next == s->sym_end); \ } # define _tr_tally_dist(s, distance, length, flush) \ { uch len = (uch)(length); \ ush dist = (ush)(distance); \ - s->d_buf[s->last_lit] = dist; \ - s->l_buf[s->last_lit++] = len; \ + s->sym_buf[s->sym_next++] = dist; \ + s->sym_buf[s->sym_next++] = dist >> 8; \ + s->sym_buf[s->sym_next++] = len; \ dist--; \ s->dyn_ltree[_length_code[len]+LITERALS+1].Freq++; \ s->dyn_dtree[d_code(dist)].Freq++; \ - flush = (s->last_lit == s->lit_bufsize-1); \ + flush = (s->sym_next == s->sym_end); \ } #else # define _tr_tally_lit(s, c, flush) flush = _tr_tally(s, 0, c) diff -Npur zlib-1.2.11/trees.c zlib-1.2.11-change/trees.c --- zlib-1.2.11/trees.c 2017-01-16 01:07:14.000000000 +0800 +++ zlib-1.2.11-change/trees.c 2022-07-28 05:00:04.094268034 +0800 @@ -416,7 +416,7 @@ local void init_block(s) s->dyn_ltree[END_BLOCK].Freq = 1; s->opt_len = s->static_len = 0L; - s->last_lit = s->matches = 0; + s->sym_next = s->matches = 0; } #define SMALLEST 1 @@ -947,7 +947,7 @@ void ZLIB_INTERNAL _tr_flush_block(s, bu Tracev((stderr, "\nopt %lu(%lu) stat %lu(%lu) stored %lu lit %u ", opt_lenb, s->opt_len, static_lenb, s->static_len, stored_len, - s->last_lit)); + s->sym_next / 3)); if (static_lenb <= opt_lenb) opt_lenb = static_lenb; @@ -1016,8 +1016,9 @@ int ZLIB_INTERNAL _tr_tally (s, dist, lc unsigned dist; /* distance of matched string */ unsigned lc; /* match length-MIN_MATCH or unmatched char (if dist==0) */ { - s->d_buf[s->last_lit] = (ush)dist; - s->l_buf[s->last_lit++] = (uch)lc; + s->sym_buf[s->sym_next++] = dist; + s->sym_buf[s->sym_next++] = dist >> 8; + s->sym_buf[s->sym_next++] = lc; if (dist == 0) { /* lc is the 
unmatched char */ s->dyn_ltree[lc].Freq++; @@ -1032,30 +1033,7 @@ int ZLIB_INTERNAL _tr_tally (s, dist, lc s->dyn_ltree[_length_code[lc]+LITERALS+1].Freq++; s->dyn_dtree[d_code(dist)].Freq++; } - -#ifdef TRUNCATE_BLOCK - /* Try to guess if it is profitable to stop the current block here */ - if ((s->last_lit & 0x1fff) == 0 && s->level > 2) { - /* Compute an upper bound for the compressed length */ - ulg out_length = (ulg)s->last_lit*8L; - ulg in_length = (ulg)((long)s->strstart - s->block_start); - int dcode; - for (dcode = 0; dcode < D_CODES; dcode++) { - out_length += (ulg)s->dyn_dtree[dcode].Freq * - (5L+extra_dbits[dcode]); - } - out_length >>= 3; - Tracev((stderr,"\nlast_lit %u, in %ld, out ~%ld(%ld%%) ", - s->last_lit, in_length, out_length, - 100L - out_length*100L/in_length)); - if (s->matches < s->last_lit/2 && out_length < in_length/2) return 1; - } -#endif - return (s->last_lit == s->lit_bufsize-1); - /* We avoid equality with lit_bufsize because of wraparound at 64K - * on 16 bit machines and because stored blocks are restricted to - * 64K-1 bytes. - */ + return (s->sym_next == s->sym_end); } /* =========================================================================== @@ -1068,13 +1046,14 @@ local void compress_block(s, ltree, dtre { unsigned dist; /* distance of matched string */ int lc; /* match length or unmatched char (if dist == 0) */ - unsigned lx = 0; /* running index in l_buf */ + unsigned sx = 0; /* running index in sym_buf */ unsigned code; /* the code to send */ int extra; /* number of extra bits to send */ - if (s->last_lit != 0) do { - dist = s->d_buf[lx]; - lc = s->l_buf[lx++]; + if (s->sym_next != 0) do { + dist = s->sym_buf[sx++] & 0xff; + dist += (unsigned)(s->sym_buf[sx++] & 0xff) << 8; + lc = s->sym_buf[sx++]; if (dist == 0) { send_code(s, lc, ltree); /* send a literal byte */ Tracecv(isgraph(lc), (stderr," '%c' ", lc)); @@ -1099,11 +1078,10 @@ local void compress_block(s, ltree, dtre } } /* literal or match pair ? */ - /* Check that the overlay between pending_buf and d_buf+l_buf is ok: */ - Assert((uInt)(s->pending) < s->lit_bufsize + 2*lx, - "pendingBuf overflow"); + /* Check that the overlay between pending_buf and sym_buf is ok: */ + Assert(s->pending < s->lit_bufsize + sx, "pendingBuf overflow"); - } while (lx < s->last_lit); + } while (sx < s->sym_next); send_code(s, END_BLOCK, ltree); } ================================================ FILE: third_party/patch/zlib/CVE-2022-37434.patch ================================================ diff -Npur zlib-1.2.11/inflate.c zlib-1.2.11-change/inflate.c --- zlib-1.2.11/inflate.c 2017-01-01 15:37:10.000000000 +0800 +++ zlib-1.2.11-change/inflate.c 2022-08-17 06:25:06.033176873 +0800 @@ -759,8 +759,9 @@ int flush; if (copy > have) copy = have; if (copy) { if (state->head != Z_NULL && - state->head->extra != Z_NULL) { - len = state->head->extra_len - state->length; + state->head->extra != Z_NULL && + (len = state->head->extra_len - state->length) < + state->head->extra_max) { zmemcpy(state->head->extra + len, next, len + copy > state->head->extra_max ? 
state->head->extra_max - len : copy); ================================================ FILE: third_party/securec/CMakeLists.txt ================================================ SET(CMAKE_BUILD_TYPE "Debug") if (CMAKE_SYSTEM_NAME MATCHES "Windows") SET(CMAKE_C_FLAGS_DEBUG "$ENV{CFLAGS} -fPIC -O0 -Wall -Wno-deprecated-declarations -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -fstack-protector-all") else() SET(CMAKE_C_FLAGS_DEBUG "$ENV{CFLAGS} -fPIC -O0 -Wall -Wno-deprecated-declarations -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -fstack-protector-all -D_LIBCPP_INLINE_VISIBILITY='' -D'_LIBCPP_EXTERN_TEMPLATE(...)='") endif() SET(CMAKE_C_FLAGS_RELEASE "$ENV{CFLAGS} -fPIC -O3 -Wall -Wno-deprecated-declarations -fstack-protector-all") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) #add flags set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -I/usr/local/include -Werror") include_directories(./include) add_subdirectory(src) ================================================ FILE: third_party/securec/include/securec.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef __SECUREC_H__5D13A042_DC3F_4ED9_A8D1_882811274C27 #define __SECUREC_H__5D13A042_DC3F_4ED9_A8D1_882811274C27 #include "securectype.h" #include <stdarg.h> #ifndef SECUREC_HAVE_ERRNO_H #if SECUREC_IN_KERNEL #define SECUREC_HAVE_ERRNO_H 0 #else #define SECUREC_HAVE_ERRNO_H 1 #endif #endif /* EINVAL and ERANGE may be defined in errno.h */ #if SECUREC_HAVE_ERRNO_H #include <errno.h> #endif /* define error codes */ #if defined(SECUREC_NEED_ERRNO_TYPE) || !defined(__STDC_WANT_LIB_EXT1__) || \ (defined(__STDC_WANT_LIB_EXT1__) && (__STDC_WANT_LIB_EXT1__ == 0)) #ifndef SECUREC_DEFINED_ERRNO_TYPE #define SECUREC_DEFINED_ERRNO_TYPE /* just check whether the macro definition exists. */ #ifndef errno_t typedef int errno_t; #endif #endif #endif /* success */ #ifndef EOK #define EOK 0 #endif #ifndef EINVAL /* The src buffer is not correct and the destination buffer cannot be reset */ #define EINVAL 22 #endif #ifndef EINVAL_AND_RESET /* Once the error is detected, the dest buffer must be reset! */ #define EINVAL_AND_RESET (22 | 128) #endif #ifndef ERANGE /* The destination buffer is not long enough and the destination buffer cannot be reset */ #define ERANGE 34 #endif #ifndef ERANGE_AND_RESET /* Once the error is detected, the dest buffer must be reset! */ #define ERANGE_AND_RESET (34 | 128) #endif #ifndef EOVERLAP_AND_RESET /* Once a buffer overlap is detected, the dest buffer must be reset! */ #define EOVERLAP_AND_RESET (54 | 128) #endif /* if you need to export the functions of this library from a Win32 DLL, use __declspec(dllexport) */ #ifndef SECUREC_API #if defined(SECUREC_DLL_EXPORT) #define SECUREC_API __declspec(dllexport) #elif defined(SECUREC_DLL_IMPORT) #define SECUREC_API __declspec(dllimport) #else /* Standardized function declaration.
If a security function is declared in your code, * it may cause a compilation warning; please delete any declaration of the security functions from your code. * Adding extern under Windows will cause the system's inline functions to be expanded, * so extern is not added by default */ #if defined(_MSC_VER) #define SECUREC_API #else #define SECUREC_API extern #endif #endif #endif #ifdef __cplusplus extern "C" { #endif /* * Description: The GetHwSecureCVersion function gets the SecureC version string and version number. * Parameter: verNumber - used to store the version number * Return: version string */ SECUREC_API const char *GetHwSecureCVersion(unsigned short *verNumber); #if SECUREC_ENABLE_MEMSET /* * Description: The memset_s function copies the value of c (converted to an unsigned char) into each of * the first count characters of the object pointed to by dest. * Parameter: dest - destination address * Parameter: destMax - The maximum length of the destination buffer * Parameter: c - the value to be copied * Parameter: count - fills the first count characters of dest * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t memset_s(void *dest, size_t destMax, int c, size_t count); #endif #ifndef SECUREC_ONLY_DECLARE_MEMSET #define SECUREC_ONLY_DECLARE_MEMSET 0 #endif #if SECUREC_ONLY_DECLARE_MEMSET == 0 #if SECUREC_ENABLE_MEMMOVE /* * Description: The memmove_s function copies n characters from the object pointed to by src * into the object pointed to by dest. * Parameter: dest - destination address * Parameter: destMax - The maximum length of the destination buffer * Parameter: src - source address * Parameter: count - copies count characters from the src * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t memmove_s(void *dest, size_t destMax, const void *src, size_t count); #endif #if SECUREC_ENABLE_MEMCPY /* * Description: The memcpy_s function copies n characters from the object pointed to * by src into the object pointed to by dest.
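 *
 * Editorial illustration (not part of the original header), a hedged
 * sketch, assuming securec.h is included:
 *
 *     char dst[32];
 *     const char src[] = "payload";
 *     errno_t rc = memcpy_s(dst, sizeof(dst), src, sizeof(src));
 *
 * rc is EOK on success; destMax (sizeof(dst)) is the capacity of the
 * destination, while count (sizeof(src)) is the number of bytes copied,
 * and the call fails rather than overflowing when count > destMax.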
* Parameter: dest - destination address * Parameter: destMax - The maximum length of the destination buffer * Parameter: src - source address * Parameter: count - copies count characters from the src * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t memcpy_s(void *dest, size_t destMax, const void *src, size_t count); #endif #if SECUREC_ENABLE_STRCPY /* * Description: The strcpy_s function copies the string pointed to by strSrc (including * the terminating null character) into the array pointed to by strDest * Parameter: strDest - destination address * Parameter: destMax - The maximum length of the destination buffer (including the terminating null character) * Parameter: strSrc - source address * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t strcpy_s(char *strDest, size_t destMax, const char *strSrc); #endif #if SECUREC_ENABLE_STRNCPY /* * Description: The strncpy_s function copies not more than count successive characters (not including * the terminating null character) * from the array pointed to by strSrc to the array pointed to by strDest * Parameter: strDest - destination address * Parameter: destMax - The maximum length of the destination buffer (including the terminating null character) * Parameter: strSrc - source address * Parameter: count - copies count characters from the src * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t strncpy_s(char *strDest, size_t destMax, const char *strSrc, size_t count); #endif #if SECUREC_ENABLE_STRCAT /* * Description: The strcat_s function appends a copy of the string pointed to by strSrc (including * the terminating null character) * to the end of the string pointed to by strDest * Parameter: strDest - destination address * Parameter: destMax - The maximum length of the destination buffer (including the terminating null character) * Parameter: strSrc - source address * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t strcat_s(char *strDest, size_t destMax, const char *strSrc); #endif #if SECUREC_ENABLE_STRNCAT /* * Description: The strncat_s function appends not more than count successive characters (not including * the terminating null character) * from the array pointed to by strSrc to the end of the string pointed to by strDest. * Parameter: strDest - destination address * Parameter: destMax - The maximum length of the destination buffer (including the terminating null character) * Parameter: strSrc - source address * Parameter: count - copies count characters from the src * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t strncat_s(char *strDest, size_t destMax, const char *strSrc, size_t count); #endif #if SECUREC_ENABLE_VSPRINTF /* * Description: The vsprintf_s function is equivalent to the vsprintf function except for the destMax parameter * and the explicit runtime-constraint violations * Parameter: strDest - produces output according to the format, written to the character string strDest * Parameter: destMax - The maximum length of the destination buffer (including the terminating null character) * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of characters printed (not including the terminating null byte ('\0')); * if an error occurred, -1 is returned.
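 *
 * Editorial illustration (not part of the original header): the typical
 * use is inside a variadic wrapper, mirroring the pattern fscanf_s.c in
 * this directory uses for vfscanf_s. A hedged sketch with a hypothetical
 * function name, assuming securec.h is included:
 *
 *     int LogLine(char *buf, size_t bufMax, const char *fmt, ...)
 *     {
 *         va_list ap;
 *         int n;
 *         va_start(ap, fmt);
 *         n = vsprintf_s(buf, bufMax, fmt, ap);
 *         va_end(ap);
 *         return n;
 *     }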
*/ SECUREC_API int vsprintf_s(char *strDest, size_t destMax, const char *format, va_list argList) SECUREC_ATTRIBUTE(3, 0); #endif #if SECUREC_ENABLE_SPRINTF /* * Description: The sprintf_s function is equivalent to the sprintf function except for the destMax parameter * and the explicit runtime-constraint violations * Parameter: strDest - produces output according to the format, written to the character string strDest * Parameter: destMax - The maximum length of the destination buffer (including the terminating null byte ('\0')) * Parameter: format - format string * Return: the number of characters printed (not including the terminating null byte ('\0')); * if an error occurred, -1 is returned. */ SECUREC_API int sprintf_s(char *strDest, size_t destMax, const char *format, ...) SECUREC_ATTRIBUTE(3, 4); #endif #if SECUREC_ENABLE_VSNPRINTF /* * Description: The vsnprintf_s function is equivalent to the vsnprintf function except for the * destMax/count parameters and the explicit runtime-constraint violations * Parameter: strDest - produces output according to the format, written to the character string strDest * Parameter: destMax - The maximum length of the destination buffer (including the terminating null byte ('\0')) * Parameter: count - do not write more than count bytes to strDest (not including the terminating null byte ('\0')) * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of characters printed (not including the terminating null byte ('\0')); * if an error occurred, -1 is returned. Pay special attention to returning -1 when truncation occurs */ SECUREC_API int vsnprintf_s(char *strDest, size_t destMax, size_t count, const char *format, va_list argList) SECUREC_ATTRIBUTE(4, 0); #endif #if SECUREC_ENABLE_SNPRINTF /* * Description: The snprintf_s function is equivalent to the snprintf function except for the * destMax/count parameters and the explicit runtime-constraint violations * Parameter: strDest - produces output according to the format, written to the character string strDest * Parameter: destMax - The maximum length of the destination buffer (including the terminating null byte ('\0')) * Parameter: count - do not write more than count bytes to strDest (not including the terminating null byte ('\0')) * Parameter: format - format string * Return: the number of characters printed (not including the terminating null byte ('\0')); * if an error occurred, -1 is returned. Pay special attention to returning -1 when truncation occurs */ SECUREC_API int snprintf_s(char *strDest, size_t destMax, size_t count, const char *format, ...)
SECUREC_ATTRIBUTE(4, 5); #endif #if SECUREC_SNPRINTF_TRUNCATED /* * Description: The vsnprintf_truncated_s function is equivalent to the vsnprintf_s function except that * there is no count parameter and the return value differs * Parameter: strDest - produces output according to the format, written to the character string strDest * Parameter: destMax - The maximum length of the destination buffer (including the terminating null byte ('\0')) * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of characters printed (not including the terminating null byte ('\0')); * if an error occurred, -1 is returned. Pay special attention to returning destMax - 1 when truncation occurs */ SECUREC_API int vsnprintf_truncated_s(char *strDest, size_t destMax, const char *format, va_list argList) SECUREC_ATTRIBUTE(3, 0); /* * Description: The snprintf_truncated_s function is equivalent to the snprintf_s function except that * there is no count parameter and the return value differs * Parameter: strDest - produces output according to the format, written to the character string strDest * Parameter: destMax - The maximum length of the destination buffer (including the terminating null byte ('\0')) * Parameter: format - format string * Return: the number of characters printed (not including the terminating null byte ('\0')); * if an error occurred, -1 is returned. Pay special attention to returning destMax - 1 when truncation occurs */ SECUREC_API int snprintf_truncated_s(char *strDest, size_t destMax, const char *format, ...) SECUREC_ATTRIBUTE(3, 4); #endif #if SECUREC_ENABLE_SCANF /* * Description: The scanf_s function is equivalent to fscanf_s with the argument stdin * interposed before the arguments to scanf_s * Parameter: format - format string * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int scanf_s(const char *format, ...); #endif #if SECUREC_ENABLE_VSCANF /* * Description: The vscanf_s function is equivalent to scanf_s, with the variable argument list replaced by argList * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int vscanf_s(const char *format, va_list argList); #endif #if SECUREC_ENABLE_SSCANF /* * Description: The sscanf_s function is equivalent to fscanf_s, except that input is obtained from a * string (specified by the argument buffer) rather than from a stream * Parameter: buffer - read characters from buffer * Parameter: format - format string * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int sscanf_s(const char *buffer, const char *format, ...); #endif #if SECUREC_ENABLE_VSSCANF /* * Description: The vsscanf_s function is equivalent to sscanf_s, with the variable argument list * replaced by argList * Parameter: buffer - read characters from buffer * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of input items assigned; if an error occurred, -1 is returned.
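 *
 * Editorial illustration (not part of the original header): with this
 * family, the c, s and [ conversion specifiers consume a pair of
 * arguments, the destination buffer and its size. A hedged sketch,
 * assuming securec.h is included:
 *
 *     char word[16];
 *     int year = 0;
 *     int n = sscanf_s("zlib 1995", "%s %d", word, sizeof(word), &year);
 *
 * On success n == 2; sizeof(word) is the size argument paired with %s.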
*/ SECUREC_API int vsscanf_s(const char *buffer, const char *format, va_list argList); #endif #if SECUREC_ENABLE_FSCANF /* * Description: The fscanf_s function is equivalent to fscanf except that the c, s, and [ conversion specifiers * apply to a pair of arguments (unless assignment suppression is indicated by a *) * Parameter: stream - stdio file stream * Parameter: format - format string * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int fscanf_s(FILE *stream, const char *format, ...); #endif #if SECUREC_ENABLE_VFSCANF /* * Description: The vfscanf_s function is equivalent to fscanf_s, with the variable argument list * replaced by argList * Parameter: stream - stdio file stream * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int vfscanf_s(FILE *stream, const char *format, va_list argList); #endif #if SECUREC_ENABLE_STRTOK /* * Description: The strtok_s function parses a string into a sequence of tokens, * replacing each character in strToken that matches a character in the strDelimit set with '\0'. * On the first call to strtok_s the string to be parsed should be specified in strToken. * In each subsequent call that parses the same string, strToken should be NULL * Parameter: strToken - the string to be delimited * Parameter: strDelimit - specifies a set of characters that delimit the tokens in the parsed string * Parameter: context - a pointer to a char * variable that is used internally by the strtok_s function * Return: on the first call, the address of the first character that is not '\0'; otherwise NULL is returned. * In subsequent calls (strToken set to NULL and the same context as in the previous call), * NULL is returned if the *context string length is 0; otherwise *context is returned. */ SECUREC_API char *strtok_s(char *strToken, const char *strDelimit, char **context); #endif #if SECUREC_ENABLE_GETS && SECUREC_IN_KERNEL == 0 /* * Description: The gets_s function reads at most one less than the number of characters specified * by destMax from the stream pointed to by stdin, into the array pointed to by buffer * Parameter: buffer - destination address * Parameter: destMax - The maximum length of the destination buffer (including the terminating null character) * Return: buffer if there was no runtime-constraint violation; if an error occurred, NULL is returned. */ SECUREC_API char *gets_s(char *buffer, size_t destMax); #endif #if SECUREC_ENABLE_WCHAR_FUNC #if SECUREC_ENABLE_MEMCPY /* * Description: The wmemcpy_s function copies n successive wide characters from the object pointed to * by src into the object pointed to by dest. * Parameter: dest - destination address * Parameter: destMax - The maximum length of the destination buffer * Parameter: src - source address * Parameter: count - copies count wide characters from the src * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t wmemcpy_s(wchar_t *dest, size_t destMax, const wchar_t *src, size_t count); #endif #if SECUREC_ENABLE_MEMMOVE /* * Description: The wmemmove_s function copies n successive wide characters from the object * pointed to by src into the object pointed to by dest.
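 *
 * Editorial illustration (not part of the original header): for the wide
 * variants, destMax and count are measured in wide characters, not bytes.
 * A hedged sketch, assuming securec.h is included:
 *
 *     wchar_t wdst[8];
 *     const wchar_t wsrc[] = L"abc";
 *     errno_t rc = wmemcpy_s(wdst, 8, wsrc, 4);
 *
 * Here 8 is the capacity of wdst and 4 covers L"abc" plus its terminator,
 * both counted in wchar_t units; rc is EOK on success.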
* Parameter: dest - destination address * Parameter: destMax - The maximum length of the destination buffer * Parameter: src - source address * Parameter: count - copies count wide characters from the src * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t wmemmove_s(wchar_t *dest, size_t destMax, const wchar_t *src, size_t count); #endif #if SECUREC_ENABLE_STRCPY /* * Description: The wcscpy_s function copies the wide string pointed to by strSrc (including the terminating * null wide character) into the array pointed to by strDest * Parameter: strDest - destination address * Parameter: destMax - The maximum length of the destination buffer * Parameter: strSrc - source address * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t wcscpy_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc); #endif #if SECUREC_ENABLE_STRNCPY /* * Description: The wcsncpy_s function copies not more than count successive wide characters (not including the * terminating null wide character) from the array pointed to by strSrc to the array pointed to by strDest * Parameter: strDest - destination address * Parameter: destMax - The maximum length of the destination buffer (including the terminating null wide character) * Parameter: strSrc - source address * Parameter: count - copies count wide characters from the src * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t wcsncpy_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count); #endif #if SECUREC_ENABLE_STRCAT /* * Description: The wcscat_s function appends a copy of the wide string pointed to by strSrc (including the * terminating null wide character) to the end of the wide string pointed to by strDest * Parameter: strDest - destination address * Parameter: destMax - The maximum length of the destination buffer (including the terminating null wide character) * Parameter: strSrc - source address * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t wcscat_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc); #endif #if SECUREC_ENABLE_STRNCAT /* * Description: The wcsncat_s function appends not more than count successive wide characters (not including the * terminating null wide character) from the array pointed to by strSrc to the end of the wide string pointed to * by strDest. * Parameter: strDest - destination address * Parameter: destMax - The maximum length of the destination buffer (including the terminating null wide character) * Parameter: strSrc - source address * Parameter: count - copies count wide characters from the src * Return: EOK if there was no runtime-constraint violation */ SECUREC_API errno_t wcsncat_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count); #endif #if SECUREC_ENABLE_STRTOK /* * Description: The wcstok_s function is the wide-character equivalent of the strtok_s function * Parameter: strToken - the string to be delimited * Parameter: strDelimit - specifies a set of characters that delimit the tokens in the parsed string * Parameter: context - a pointer to a wchar_t * variable that is used internally by the wcstok_s function * Return: a pointer to the first character of a token, or a null pointer if there is no token * or there is a runtime-constraint violation.
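 *
 * Editorial illustration (not part of the original header), a hedged
 * sketch of the usual tokenizing loop, assuming securec.h is included:
 *
 *     wchar_t line[] = L"a,b,c";
 *     wchar_t *ctx = NULL;
 *     wchar_t *tok = wcstok_s(line, L",", &ctx);
 *     while (tok != NULL) {
 *         tok = wcstok_s(NULL, L",", &ctx);
 *     }
 *
 * The first call passes the string; later calls pass NULL with the same
 * ctx, exactly as described for strtok_s above.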
*/ SECUREC_API wchar_t *wcstok_s(wchar_t *strToken, const wchar_t *strDelimit, wchar_t **context); #endif #if SECUREC_ENABLE_VSPRINTF /* * Description: The vswprintf_s function is the wide-character equivalent of the vsprintf_s function * Parameter: strDest - produces output according to the format, written to the character string strDest * Parameter: destMax - The maximum length of the destination buffer (including the terminating null wide character) * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of characters printed (not including the terminating null wide character); * if an error occurred, -1 is returned. */ SECUREC_API int vswprintf_s(wchar_t *strDest, size_t destMax, const wchar_t *format, va_list argList); #endif #if SECUREC_ENABLE_SPRINTF /* * Description: The swprintf_s function is the wide-character equivalent of the sprintf_s function * Parameter: strDest - produces output according to the format, written to the character string strDest * Parameter: destMax - The maximum length of the destination buffer (including the terminating null wide character) * Parameter: format - format string * Return: the number of characters printed (not including the terminating null wide character); * if an error occurred, -1 is returned. */ SECUREC_API int swprintf_s(wchar_t *strDest, size_t destMax, const wchar_t *format, ...); #endif #if SECUREC_ENABLE_FSCANF /* * Description: The fwscanf_s function is the wide-character equivalent of the fscanf_s function * Parameter: stream - stdio file stream * Parameter: format - format string * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int fwscanf_s(FILE *stream, const wchar_t *format, ...); #endif #if SECUREC_ENABLE_VFSCANF /* * Description: The vfwscanf_s function is the wide-character equivalent of the vfscanf_s function * Parameter: stream - stdio file stream * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int vfwscanf_s(FILE *stream, const wchar_t *format, va_list argList); #endif #if SECUREC_ENABLE_SCANF /* * Description: The wscanf_s function is the wide-character equivalent of the scanf_s function * Parameter: format - format string * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int wscanf_s(const wchar_t *format, ...); #endif #if SECUREC_ENABLE_VSCANF /* * Description: The vwscanf_s function is the wide-character equivalent of the vscanf_s function * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int vwscanf_s(const wchar_t *format, va_list argList); #endif #if SECUREC_ENABLE_SSCANF /* * Description: The swscanf_s function is the wide-character equivalent of the sscanf_s function * Parameter: buffer - read characters from buffer * Parameter: format - format string * Return: the number of input items assigned; if an error occurred, -1 is returned.
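 *
 * Editorial illustration (not part of the original header), a hedged
 * sketch, assuming securec.h is included:
 *
 *     int major = 0, minor = 0;
 *     int n = swscanf_s(L"1.2", L"%d.%d", &major, &minor);
 *
 * On success n == 2, with major == 1 and minor == 2; only string and
 * character conversions need the extra size argument described above.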
*/ SECUREC_API int swscanf_s(const wchar_t *buffer, const wchar_t *format, ...); #endif #if SECUREC_ENABLE_VSSCANF /* * Description: The vswscanf_s function is the wide-character equivalent of the vsscanf_s function * Parameter: buffer - read characters from buffer * Parameter: format - format string * Parameter: argList - used instead of a variable number of arguments * Return: the number of input items assigned; if an error occurred, -1 is returned. */ SECUREC_API int vswscanf_s(const wchar_t *buffer, const wchar_t *format, va_list argList); #endif #endif /* SECUREC_ENABLE_WCHAR_FUNC */ #endif /* these functions are used by macros and must be declared here, also to avoid warnings about missing function declarations */ extern errno_t strncpy_error(char *strDest, size_t destMax, const char *strSrc, size_t count); extern errno_t strcpy_error(char *strDest, size_t destMax, const char *strSrc); #if SECUREC_WITH_PERFORMANCE_ADDONS /* these functions are used by macros */ extern errno_t memset_sOptAsm(void *dest, size_t destMax, int c, size_t count); extern errno_t memset_sOptTc(void *dest, size_t destMax, int c, size_t count); extern errno_t memcpy_sOptAsm(void *dest, size_t destMax, const void *src, size_t count); extern errno_t memcpy_sOptTc(void *dest, size_t destMax, const void *src, size_t count); /* strcpy_sp is a macro, NOT a function, in performance optimization mode. */ #define strcpy_sp(dest, destMax, src) ((__builtin_constant_p((destMax)) && \ __builtin_constant_p((src))) ? \ SECUREC_STRCPY_SM((dest), (destMax), (src)) : \ strcpy_s((dest), (destMax), (src))) /* strncpy_sp is a macro, NOT a function, in performance optimization mode. */ #define strncpy_sp(dest, destMax, src, count) ((__builtin_constant_p((count)) && \ __builtin_constant_p((destMax)) && \ __builtin_constant_p((src))) ? \ SECUREC_STRNCPY_SM((dest), (destMax), (src), (count)) : \ strncpy_s((dest), (destMax), (src), (count))) /* strcat_sp is a macro, NOT a function, in performance optimization mode. */ #define strcat_sp(dest, destMax, src) ((__builtin_constant_p((destMax)) && \ __builtin_constant_p((src))) ? \ SECUREC_STRCAT_SM((dest), (destMax), (src)) : \ strcat_s((dest), (destMax), (src))) /* strncat_sp is a macro, NOT a function, in performance optimization mode. */ #define strncat_sp(dest, destMax, src, count) ((__builtin_constant_p((count)) && \ __builtin_constant_p((destMax)) && \ __builtin_constant_p((src))) ? \ SECUREC_STRNCAT_SM((dest), (destMax), (src), (count)) : \ strncat_s((dest), (destMax), (src), (count))) /* memcpy_sp is a macro, NOT a function, in performance optimization mode. */ #define memcpy_sp(dest, destMax, src, count) (__builtin_constant_p((count)) ? \ (SECUREC_MEMCPY_SM((dest), (destMax), (src), (count))) : \ (__builtin_constant_p((destMax)) ? \ (((size_t)(destMax) > 0 && \ (((unsigned long long)(destMax) & \ (unsigned long long)(-2)) < SECUREC_MEM_MAX_LEN)) ? \ memcpy_sOptTc((dest), (destMax), (src), (count)) : ERANGE) : \ memcpy_sOptAsm((dest), (destMax), (src), (count)))) /* memset_sp is a macro, NOT a function, in performance optimization mode. */ #define memset_sp(dest, destMax, c, count) (__builtin_constant_p((count)) ? \ (SECUREC_MEMSET_SM((dest), (destMax), (c), (count))) : \ (__builtin_constant_p((destMax)) ? \ (((size_t)(destMax) > 0 && \ (((unsigned long long)(destMax) & \ (unsigned long long)(-2)) < SECUREC_MEM_MAX_LEN)) ?
\ memset_sOptTc((dest), (destMax), (c), (count)) : ERANGE) : \ memset_sOptAsm((dest), (destMax), (c), (count)))) #else #define strcpy_sp strcpy_s #define strncpy_sp strncpy_s #define strcat_sp strcat_s #define strncat_sp strncat_s #define memcpy_sp memcpy_s #define memset_sp memset_s #endif #ifdef __cplusplus } #endif /* __cplusplus */ #endif /* __SECUREC_H__5D13A042_DC3F_4ED9_A8D1_882811274C27 */ ================================================ FILE: third_party/securec/include/securectype.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef __SECURECTYPE_H__A7BBB686_AADA_451B_B9F9_44DACDAE18A7 #define __SECURECTYPE_H__A7BBB686_AADA_451B_B9F9_44DACDAE18A7 #ifndef SECUREC_USING_STD_SECURE_LIB #if defined(_MSC_VER) && _MSC_VER >= 1400 #if defined(__STDC_WANT_SECURE_LIB__) && __STDC_WANT_SECURE_LIB__ == 0 /* Security functions have been provided since vs2005, default use of system library functions */ #define SECUREC_USING_STD_SECURE_LIB 0 #else #define SECUREC_USING_STD_SECURE_LIB 1 #endif #else #define SECUREC_USING_STD_SECURE_LIB 0 #endif #endif /* Compatibility with older Secure C versions, shielding VC symbol redefinition warning */ #if defined(_MSC_VER) && _MSC_VER >= 1400 && SECUREC_USING_STD_SECURE_LIB == 0 #ifndef SECUREC_DISABLE_CRT_FUNC #define SECUREC_DISABLE_CRT_FUNC 1 #endif #ifndef SECUREC_DISABLE_CRT_IMP #define SECUREC_DISABLE_CRT_IMP 1 #endif #else /* MSC VER */ #ifndef SECUREC_DISABLE_CRT_FUNC #define SECUREC_DISABLE_CRT_FUNC 0 #endif #ifndef SECUREC_DISABLE_CRT_IMP #define SECUREC_DISABLE_CRT_IMP 0 #endif #endif #if SECUREC_DISABLE_CRT_FUNC #ifdef __STDC_WANT_SECURE_LIB__ #undef __STDC_WANT_SECURE_LIB__ #endif #define __STDC_WANT_SECURE_LIB__ 0 #endif #if SECUREC_DISABLE_CRT_IMP #ifdef _CRTIMP_ALTERNATIVE #undef _CRTIMP_ALTERNATIVE #endif #define _CRTIMP_ALTERNATIVE /* comment microsoft *_s function */ #endif /* Compile in kernel under macro control */ #ifndef SECUREC_IN_KERNEL #ifdef __KERNEL__ #define SECUREC_IN_KERNEL 1 #else #define SECUREC_IN_KERNEL 0 #endif #endif #if SECUREC_IN_KERNEL #ifndef SECUREC_ENABLE_SCANF_FILE #define SECUREC_ENABLE_SCANF_FILE 0 #endif #ifndef SECUREC_ENABLE_WCHAR_FUNC #define SECUREC_ENABLE_WCHAR_FUNC 0 #endif #else /* SECUREC_IN_KERNEL */ #ifndef SECUREC_ENABLE_SCANF_FILE #define SECUREC_ENABLE_SCANF_FILE 1 #endif #ifndef SECUREC_ENABLE_WCHAR_FUNC #define SECUREC_ENABLE_WCHAR_FUNC 1 #endif #endif /* Default secure function declaration, default declarations for non-standard functions */ #ifndef SECUREC_SNPRINTF_TRUNCATED #define SECUREC_SNPRINTF_TRUNCATED 1 #endif #if SECUREC_USING_STD_SECURE_LIB #if defined(_MSC_VER) && _MSC_VER >= 1400 /* Declare secure functions that are not available in the vs compiler */ #ifndef SECUREC_ENABLE_MEMSET #define SECUREC_ENABLE_MEMSET 1 #endif /* vs 2005 have vsnprintf_s function */ #ifndef SECUREC_ENABLE_VSNPRINTF #define SECUREC_ENABLE_VSNPRINTF 0 #endif #ifndef SECUREC_ENABLE_SNPRINTF /* vs 2005 have vsnprintf_s 
function Adapt the snprintf_s of the security function */ #define snprintf_s _snprintf_s #define SECUREC_ENABLE_SNPRINTF 0 #endif /* befor vs 2010 do not have v functions */ #if _MSC_VER <= 1600 || defined(SECUREC_FOR_V_SCANFS) #ifndef SECUREC_ENABLE_VFSCANF #define SECUREC_ENABLE_VFSCANF 1 #endif #ifndef SECUREC_ENABLE_VSCANF #define SECUREC_ENABLE_VSCANF 1 #endif #ifndef SECUREC_ENABLE_VSSCANF #define SECUREC_ENABLE_VSSCANF 1 #endif #endif #else /* _MSC_VER */ #ifndef SECUREC_ENABLE_MEMSET #define SECUREC_ENABLE_MEMSET 0 #endif #ifndef SECUREC_ENABLE_SNPRINTF #define SECUREC_ENABLE_SNPRINTF 0 #endif #ifndef SECUREC_ENABLE_VSNPRINTF #define SECUREC_ENABLE_VSNPRINTF 0 #endif #endif #ifndef SECUREC_ENABLE_MEMMOVE #define SECUREC_ENABLE_MEMMOVE 0 #endif #ifndef SECUREC_ENABLE_MEMCPY #define SECUREC_ENABLE_MEMCPY 0 #endif #ifndef SECUREC_ENABLE_STRCPY #define SECUREC_ENABLE_STRCPY 0 #endif #ifndef SECUREC_ENABLE_STRNCPY #define SECUREC_ENABLE_STRNCPY 0 #endif #ifndef SECUREC_ENABLE_STRCAT #define SECUREC_ENABLE_STRCAT 0 #endif #ifndef SECUREC_ENABLE_STRNCAT #define SECUREC_ENABLE_STRNCAT 0 #endif #ifndef SECUREC_ENABLE_SPRINTF #define SECUREC_ENABLE_SPRINTF 0 #endif #ifndef SECUREC_ENABLE_VSPRINTF #define SECUREC_ENABLE_VSPRINTF 0 #endif #ifndef SECUREC_ENABLE_SSCANF #define SECUREC_ENABLE_SSCANF 0 #endif #ifndef SECUREC_ENABLE_VSSCANF #define SECUREC_ENABLE_VSSCANF 0 #endif #ifndef SECUREC_ENABLE_SCANF #define SECUREC_ENABLE_SCANF 0 #endif #ifndef SECUREC_ENABLE_VSCANF #define SECUREC_ENABLE_VSCANF 0 #endif #ifndef SECUREC_ENABLE_FSCANF #define SECUREC_ENABLE_FSCANF 0 #endif #ifndef SECUREC_ENABLE_VFSCANF #define SECUREC_ENABLE_VFSCANF 0 #endif #ifndef SECUREC_ENABLE_STRTOK #define SECUREC_ENABLE_STRTOK 0 #endif #ifndef SECUREC_ENABLE_GETS #define SECUREC_ENABLE_GETS 0 #endif #else /* SECUREC_USE_STD_SECURE_LIB */ #ifndef SECUREC_ENABLE_MEMSET #define SECUREC_ENABLE_MEMSET 1 #endif #ifndef SECUREC_ENABLE_MEMMOVE #define SECUREC_ENABLE_MEMMOVE 1 #endif #ifndef SECUREC_ENABLE_MEMCPY #define SECUREC_ENABLE_MEMCPY 1 #endif #ifndef SECUREC_ENABLE_STRCPY #define SECUREC_ENABLE_STRCPY 1 #endif #ifndef SECUREC_ENABLE_STRNCPY #define SECUREC_ENABLE_STRNCPY 1 #endif #ifndef SECUREC_ENABLE_STRCAT #define SECUREC_ENABLE_STRCAT 1 #endif #ifndef SECUREC_ENABLE_STRNCAT #define SECUREC_ENABLE_STRNCAT 1 #endif #ifndef SECUREC_ENABLE_SPRINTF #define SECUREC_ENABLE_SPRINTF 1 #endif #ifndef SECUREC_ENABLE_VSPRINTF #define SECUREC_ENABLE_VSPRINTF 1 #endif #ifndef SECUREC_ENABLE_SNPRINTF #define SECUREC_ENABLE_SNPRINTF 1 #endif #ifndef SECUREC_ENABLE_VSNPRINTF #define SECUREC_ENABLE_VSNPRINTF 1 #endif #ifndef SECUREC_ENABLE_SSCANF #define SECUREC_ENABLE_SSCANF 1 #endif #ifndef SECUREC_ENABLE_VSSCANF #define SECUREC_ENABLE_VSSCANF 1 #endif #ifndef SECUREC_ENABLE_SCANF #if SECUREC_ENABLE_SCANF_FILE #define SECUREC_ENABLE_SCANF 1 #else #define SECUREC_ENABLE_SCANF 0 #endif #endif #ifndef SECUREC_ENABLE_VSCANF #if SECUREC_ENABLE_SCANF_FILE #define SECUREC_ENABLE_VSCANF 1 #else #define SECUREC_ENABLE_VSCANF 0 #endif #endif #ifndef SECUREC_ENABLE_FSCANF #if SECUREC_ENABLE_SCANF_FILE #define SECUREC_ENABLE_FSCANF 1 #else #define SECUREC_ENABLE_FSCANF 0 #endif #endif #ifndef SECUREC_ENABLE_VFSCANF #if SECUREC_ENABLE_SCANF_FILE #define SECUREC_ENABLE_VFSCANF 1 #else #define SECUREC_ENABLE_VFSCANF 0 #endif #endif #ifndef SECUREC_ENABLE_STRTOK #define SECUREC_ENABLE_STRTOK 1 #endif #ifndef SECUREC_ENABLE_GETS #define SECUREC_ENABLE_GETS 1 #endif #endif /* SECUREC_USE_STD_SECURE_LIB */ #if SECUREC_ENABLE_SCANF_FILE == 0 #if 
SECUREC_ENABLE_FSCANF #undef SECUREC_ENABLE_FSCANF #define SECUREC_ENABLE_FSCANF 0 #endif #if SECUREC_ENABLE_VFSCANF #undef SECUREC_ENABLE_VFSCANF #define SECUREC_ENABLE_VFSCANF 0 #endif #if SECUREC_ENABLE_SCANF #undef SECUREC_ENABLE_SCANF #define SECUREC_ENABLE_SCANF 0 #endif #if SECUREC_ENABLE_FSCANF #undef SECUREC_ENABLE_FSCANF #define SECUREC_ENABLE_FSCANF 0 #endif #endif #if SECUREC_IN_KERNEL #include #include #else #include #include #include #endif /* If you need high performance, enable the SECUREC_WITH_PERFORMANCE_ADDONS macro, default is enable . * The macro is automatically closed on the windows platform and linux kernel */ #ifndef SECUREC_WITH_PERFORMANCE_ADDONS #if SECUREC_IN_KERNEL #define SECUREC_WITH_PERFORMANCE_ADDONS 0 #else #define SECUREC_WITH_PERFORMANCE_ADDONS 1 #endif #endif /* if enable SECUREC_COMPATIBLE_WIN_FORMAT, the output format will be compatible to Windows. */ #if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) && !defined(SECUREC_COMPATIBLE_LINUX_FORMAT) #if !defined(SECUREC_COMPATIBLE_WIN_FORMAT) #define SECUREC_COMPATIBLE_WIN_FORMAT #endif #endif #if defined(SECUREC_COMPATIBLE_WIN_FORMAT) /* in windows platform, can't use optimized function for there is no __builtin_constant_p like function */ /* If need optimized macro, can define this: define __builtin_constant_p(x) 0 */ #ifdef SECUREC_WITH_PERFORMANCE_ADDONS #undef SECUREC_WITH_PERFORMANCE_ADDONS #define SECUREC_WITH_PERFORMANCE_ADDONS 0 #endif #endif #if defined(__VXWORKS__) || defined(__vxworks) || defined(__VXWORKS) || defined(_VXWORKS_PLATFORM_) || \ defined(SECUREC_VXWORKS_VERSION_5_4) #if !defined(SECUREC_VXWORKS_PLATFORM) #define SECUREC_VXWORKS_PLATFORM #endif #endif /* if enable SECUREC_COMPATIBLE_LINUX_FORMAT, the output format will be compatible to Linux. */ #if !(defined(SECUREC_COMPATIBLE_WIN_FORMAT) || defined(SECUREC_VXWORKS_PLATFORM)) #if !defined(SECUREC_COMPATIBLE_LINUX_FORMAT) #define SECUREC_COMPATIBLE_LINUX_FORMAT #endif #endif #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT #include #endif /* add the -DSECUREC_SUPPORT_FORMAT_WARNING compiler option to supoort -Wformat. * default does not check the format is that the same data type in the actual code * in the product is different in the original data type definition of VxWorks and Linux. 
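 *
 * Editorial example (not part of the original header): compiling with
 * -DSECUREC_SUPPORT_FORMAT_WARNING=1 makes the SECUREC_ATTRIBUTE macro
 * below expand to __attribute__((format(printf, x, y))), so a mismatch
 * such as
 *
 *     char buf[16];
 *     sprintf_s(buf, sizeof(buf), "%s", 42);
 *
 * is flagged by GCC's -Wformat diagnostics.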
*/ #ifndef SECUREC_SUPPORT_FORMAT_WARNING #define SECUREC_SUPPORT_FORMAT_WARNING 0 #endif /* SECUREC_PCLINT for tool do not recognize __attribute__ just for pclint */ #if SECUREC_SUPPORT_FORMAT_WARNING && !defined(SECUREC_PCLINT) #define SECUREC_ATTRIBUTE(x, y) __attribute__((format(printf, (x), (y)))) #else #define SECUREC_ATTRIBUTE(x, y) #endif /* SECUREC_PCLINT for tool do not recognize __builtin_expect, just for pclint */ #if defined(__GNUC__) && \ ((__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ > 3))) && \ !defined(SECUREC_PCLINT) /* This is a built-in function that can be used without a declaration, if you encounter an undeclared compilation alarm, * you can add -DSECUREC_NEED_BUILTIN_EXPECT_DECLARE to complier options */ #if defined(SECUREC_NEED_BUILTIN_EXPECT_DECLARE) long __builtin_expect(long exp, long c); #endif #define SECUREC_LIKELY(x) __builtin_expect(!!(x), 1) #define SECUREC_UNLIKELY(x) __builtin_expect(!!(x), 0) #else #define SECUREC_LIKELY(x) (x) #define SECUREC_UNLIKELY(x) (x) #endif /* define the max length of the string */ #ifndef SECUREC_STRING_MAX_LEN #define SECUREC_STRING_MAX_LEN (0x7fffffffUL) #endif #define SECUREC_WCHAR_STRING_MAX_LEN (SECUREC_STRING_MAX_LEN / sizeof(wchar_t)) /* add SECUREC_MEM_MAX_LEN for memcpy and memmove */ #ifndef SECUREC_MEM_MAX_LEN #define SECUREC_MEM_MAX_LEN (0x7fffffffUL) #endif #define SECUREC_WCHAR_MEM_MAX_LEN (SECUREC_MEM_MAX_LEN / sizeof(wchar_t)) #if SECUREC_STRING_MAX_LEN > 0x7fffffff #error "max string is 2G" #endif #if (defined(__GNUC__) && defined(__SIZEOF_POINTER__)) #if (__SIZEOF_POINTER__ != 4) && (__SIZEOF_POINTER__ != 8) #error "unsupported system" #endif #endif #if defined(_WIN64) || defined(WIN64) || defined(__LP64__) || defined(_LP64) #define SECUREC_ON_64BITS #endif #if (!defined(SECUREC_ON_64BITS) && defined(__GNUC__) && defined(__SIZEOF_POINTER__)) #if __SIZEOF_POINTER__ == 8 #define SECUREC_ON_64BITS #endif #endif #if defined(__SVR4) || defined(__svr4__) #define SECUREC_ON_SOLARIS #endif #if (defined(__hpux) || defined(_AIX) || defined(SECUREC_ON_SOLARIS)) #define SECUREC_ON_UNIX #endif /* codes should run under the macro SECUREC_COMPATIBLE_LINUX_FORMAT in unknow system on default, * and strtold. The function * strtold is referenced first at ISO9899:1999(C99), and some old compilers can * not support these functions. Here provides a macro to open these functions: * SECUREC_SUPPORT_STRTOLD -- if defined, strtold will be used */ #ifndef SECUREC_SUPPORT_STRTOLD #define SECUREC_SUPPORT_STRTOLD 0 #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT)) #if defined(__USE_ISOC99) || \ (defined(_AIX) && defined(_ISOC99_SOURCE)) || \ (defined(__hpux) && defined(__ia64)) || \ (defined(SECUREC_ON_SOLARIS) && (!defined(_STRICT_STDC) && !defined(__XOPEN_OR_POSIX)) || \ defined(_STDC_C99) || defined(__EXTENSIONS__)) #undef SECUREC_SUPPORT_STRTOLD #define SECUREC_SUPPORT_STRTOLD 1 #endif #endif #if ((defined(SECUREC_WRLINUX_BELOW4) || defined(_WRLINUX_BELOW4_))) #undef SECUREC_SUPPORT_STRTOLD #define SECUREC_SUPPORT_STRTOLD 0 #endif #endif #if SECUREC_WITH_PERFORMANCE_ADDONS #ifndef SECUREC_TWO_MIN #define SECUREC_TWO_MIN(a, b) ((a) < (b) ? (a) : (b)) #endif /* for strncpy_s performance optimization */ #define SECUREC_STRNCPY_SM(dest, destMax, src, count) \ (((void *)(dest) != NULL && (void *)(src) != NULL && (size_t)(destMax) > 0 && \ (((unsigned long long)(destMax) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN) && \ (SECUREC_TWO_MIN((size_t)(count), strlen(src)) + 1) <= (size_t)(destMax)) ? \ (((size_t)(count) < strlen(src)) ? 
(memcpy((dest), (src), (count)), *((char *)(dest) + (count)) = '\0', EOK) : \ (memcpy((dest), (src), strlen(src) + 1), EOK)) : (strncpy_error((dest), (destMax), (src), (count)))) #define SECUREC_STRCPY_SM(dest, destMax, src) \ (((void *)(dest) != NULL && (void *)(src) != NULL && (size_t)(destMax) > 0 && \ (((unsigned long long)(destMax) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN) && \ (strlen(src) + 1) <= (size_t)(destMax)) ? (memcpy((dest), (src), strlen(src) + 1), EOK) : \ (strcpy_error((dest), (destMax), (src)))) /* for strcat_s performance optimization */ #if defined(__GNUC__) #define SECUREC_STRCAT_SM(dest, destMax, src) ({ \ int catRet = EOK; \ if ((void *)(dest) != NULL && (void *)(src) != NULL && (size_t)(destMax) > 0 && \ (((unsigned long long)(destMax) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN)) { \ char *catTmpDst = (char *)(dest); \ size_t catRestSize = (destMax); \ while (catRestSize > 0 && *catTmpDst != '\0') { \ ++catTmpDst; \ --catRestSize; \ } \ if (catRestSize == 0) { \ catRet = EINVAL; \ } else if ((strlen(src) + 1) <= catRestSize) { \ memcpy(catTmpDst, (src), strlen(src) + 1); \ catRet = EOK; \ } else { \ catRet = ERANGE; \ } \ if (catRet != EOK) { \ catRet = strcat_s((dest), (destMax), (src)); \ } \ } else { \ catRet = strcat_s((dest), (destMax), (src)); \ } \ catRet; \ }) #else #define SECUREC_STRCAT_SM(dest, destMax, src) strcat_s((dest), (destMax), (src)) #endif /* for strncat_s performance optimization */ #if defined(__GNUC__) #define SECUREC_STRNCAT_SM(dest, destMax, src, count) ({ \ int ncatRet = EOK; \ if ((void *)(dest) != NULL && (void *)(src) != NULL && (size_t)(destMax) > 0 && \ (((unsigned long long)(destMax) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN) && \ (((unsigned long long)(count) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN)) { \ char *ncatTmpDest = (char *)(dest); \ size_t ncatRestSize = (size_t)(destMax); \ while (ncatRestSize > 0 && *ncatTmpDest != '\0') { \ ++ncatTmpDest; \ --ncatRestSize; \ } \ if (ncatRestSize == 0) { \ ncatRet = EINVAL; \ } else if ((SECUREC_TWO_MIN((count), strlen(src)) + 1) <= ncatRestSize) { \ if ((size_t)(count) < strlen(src)) { \ memcpy(ncatTmpDest, (src), (count)); \ *(ncatTmpDest + (count)) = '\0'; \ } else { \ memcpy(ncatTmpDest, (src), strlen(src) + 1); \ } \ } else { \ ncatRet = ERANGE; \ } \ if (ncatRet != EOK) { \ ncatRet = strncat_s((dest), (destMax), (src), (count)); \ } \ } else { \ ncatRet = strncat_s((dest), (destMax), (src), (count)); \ } \ ncatRet; \ }) #else #define SECUREC_STRNCAT_SM(dest, destMax, src, count) strncat_s((dest), (destMax), (src), (count)) #endif /* SECUREC_MEMCPY_SM do NOT check buffer overlap by default */ #define SECUREC_MEMCPY_SM(dest, destMax, src, count) \ (!(((size_t)(destMax) == 0) || \ (((unsigned long long)(destMax) & (unsigned long long)(-2)) > SECUREC_MEM_MAX_LEN) || \ ((size_t)(count) > (size_t)(destMax)) || ((void *)(dest)) == NULL || ((void *)(src) == NULL))? \ (memcpy((dest), (src), (count)), EOK) : \ (memcpy_s((dest), (destMax), (src), (count)))) #define SECUREC_MEMSET_SM(dest, destMax, c, count) \ (!(((size_t)(destMax) == 0) || \ (((unsigned long long)(destMax) & (unsigned long long)(-2)) > SECUREC_MEM_MAX_LEN) || \ ((void *)(dest) == NULL) || ((size_t)(count) > (size_t)(destMax))) ? 
\ (memset((dest), (c), (count)), EOK) : \ (memset_s((dest), (destMax), (c), (count)))) #endif #endif /* __SECURECTYPE_H__A7BBB686_AADA_451B_B9F9_44DACDAE18A7 */ ================================================ FILE: third_party/securec/src/CMakeLists.txt ================================================ if (CMAKE_SYSTEM_NAME MATCHES "Windows") list(APPEND SECUREC_SRCS "memset_s.c") else() aux_source_directory(. SECUREC_SRCS) endif() add_library(securec STATIC ${SECUREC_SRCS}) ================================================ FILE: third_party/securec/src/fscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securec.h" /* * * The fscanf_s function is equivalent to fscanf except that the c, s, * and [ conversion specifiers apply to a pair of arguments (unless assignment suppression is indicated by a *) * The fscanf function reads data from the current position of stream into * the locations given by argument (if any). Each argument must be a pointer * to a variable of a type that corresponds to a type specifier in format. * format controls the interpretation of the input fields and has the same * form and function as the format argument for scanf. * * * stream Pointer to FILE structure. * format Format control string, see Format Specifications. * ... Optional arguments. * * * ... The converted values are stored at the user-assigned addresses * * * Each of these functions returns the number of fields successfully converted * and assigned; the return value does not include fields that were read but * not assigned. A return value of 0 indicates that no fields were assigned. * -1 is returned if an error occurs. */ int fscanf_s(FILE *stream, const char *format, ...) { int ret; /* not initialized here: initialization would cause lint e838 */ va_list argList; va_start(argList, format); ret = vfscanf_s(stream, format, argList); va_end(argList); (void)argList; /* to clear lint e438 (last value assigned not used); the compiler will optimize this code */ return ret; } ================================================ FILE: third_party/securec/src/fwscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securec.h" /* * * The fwscanf_s function is the wide-character equivalent of the fscanf_s function * The fwscanf_s function reads data from the current position of stream into * the locations given by argument (if any).
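 *
 * Editorial illustration (not part of the original source), a hedged
 * sketch of the narrow variant; the file name is hypothetical and the
 * usual stdio and securec headers are assumed:
 *
 *     FILE *fp = fopen("config.txt", "r");
 *     int port = 0;
 *     if (fp != NULL) {
 *         if (fscanf_s(fp, "%d", &port) != 1) {
 *             port = -1;
 *         }
 *         fclose(fp);
 *     }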
Each argument must be a pointer * to a variable of a type that corresponds to a type specifier in format. * format controls the interpretation of the input fields and has the same * form and function as the format argument for scanf. * * * stream Pointer to FILE structure. * format Format control string, see Format Specifications. * ... Optional arguments. * * * ... The converted value stored in user assigned address * * * Each of these functions returns the number of fields successfully converted * and assigned; the return value does not include fields that were read but * not assigned. A return value of 0 indicates that no fields were assigned. * return -1 if an error occurs. */ int fwscanf_s(FILE *stream, const wchar_t *format, ...) { int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); ret = vfwscanf_s(stream, format, argList); va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } ================================================ FILE: third_party/securec/src/gets_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securecutil.h" static void SecTrimCRLF(char *buffer, size_t len) { int i; /* No need to determine whether integer overflow exists */ for (i = (int)(len - 1); i >= 0 && (buffer[i] == '\r' || buffer[i] == '\n'); --i) { buffer[i] = '\0'; } return; } /* * * The gets_s function reads at most one less than the number of characters * specified by destMax from the stream pointed to by stdin, into the array pointed to by buffer * The line consists of all characters up to and including * the first newline character ('\n'). gets_s then replaces the newline * character with a null character ('\0') before returning the line. * If the first character read is the end-of-file character, a null character * is stored at the beginning of buffer and NULL is returned. * * * buffer Storage location for input string. * numberOfElements The size of the buffer. * * * buffer is updated * * * buffer Successful operation * NULL Improper parameter or read fail */ char *gets_s(char *buffer, size_t numberOfElements) { size_t len; #ifdef SECUREC_COMPATIBLE_WIN_FORMAT size_t bufferSize = ((numberOfElements == (size_t)-1) ? 
SECUREC_STRING_MAX_LEN : numberOfElements); #else size_t bufferSize = numberOfElements; #endif if (buffer == NULL || bufferSize == 0 || bufferSize > SECUREC_STRING_MAX_LEN) { SECUREC_ERROR_INVALID_PARAMTER("gets_s"); return NULL; } if (fgets(buffer, (int)bufferSize, stdin) == NULL) { return NULL; } len = strlen(buffer); if (len > 0 && len < bufferSize) { SecTrimCRLF(buffer, len); } return buffer; } ================================================ FILE: third_party/securec/src/input.inl ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INPUT_INL_5D13A042_DC3F_4ED9_A8D1_882811274C27 #define INPUT_INL_5D13A042_DC3F_4ED9_A8D1_882811274C27 #if SECUREC_IN_KERNEL #include <linux/ctype.h> #ifndef EOF #define EOF (-1) #endif #else #if !defined(SECUREC_SYSAPI4VXWORKS) && !defined(SECUREC_CTYPE_MACRO_ADAPT) #include <ctype.h> #ifdef SECUREC_FOR_WCHAR #include <wctype.h> /* for iswspace */ #endif #endif #endif #define SECUREC_NUM_WIDTH_SHORT 0 #define SECUREC_NUM_WIDTH_INT 1 #define SECUREC_NUM_WIDTH_LONG 2 #define SECUREC_NUM_WIDTH_LONG_LONG 3 /* also long double */ #define SECUREC_BUF_EXT_MUL 2 #define SECUREC_BUFFERED_BLOK_SIZE 1024 #if defined(SECUREC_VXWORKS_PLATFORM) && !defined(va_copy) && !defined(__va_copy) /* the name is the same as the system macro.
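 *
 * Editorial note (not part of the original source): a va_copy fallback is
 * needed because a va_list can only be traversed once; code that must
 * read the arguments twice duplicates the list first, e.g., given an
 * active va_list ap whose next two arguments are ints:
 *
 *     va_list ap2;
 *     __va_copy(ap2, ap);
 *     int first = va_arg(ap, int);
 *     int firstAgain = va_arg(ap2, int);
 *     va_end(ap2);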
*/ #define __va_copy(d, s) do { \ size_t size_of_d = (size_t)sizeof(d); \ size_t size_of_s = (size_t)sizeof(s); \ if (size_of_d != size_of_s) { \ (void)memcpy((d), (s), sizeof(va_list)); \ } else { \ (void)memcpy(&(d), &(s), sizeof(va_list)); \ } \ } SECUREC_WHILE_ZERO #endif #define SECUREC_MULTI_BYTE_MAX_LEN 6 /* Record a flag for each bit */ #define SECUREC_BRACKET_INDEX(x) ((unsigned int)(x) >> 3) #define SECUREC_BRACKET_VALUE(x) ((unsigned char)(1 << ((unsigned int)(x) & 7))) /* Compatibility macro name cannot be modifie */ #ifndef UNALIGNED #if !(defined(_M_IA64)) && !(defined(_M_AMD64)) #define UNALIGNED #else #define UNALIGNED __unaligned #endif #endif #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) /* Max 64bit value is 0xffffffffffffffff */ #define SECUREC_MAX_64BITS_VALUE 18446744073709551615ULL #define SECUREC_MAX_64BITS_VALUE_DIV_TEN 1844674407370955161ULL #define SECUREC_MAX_64BITS_VALUE_CUT_LAST_DIGIT 18446744073709551610ULL #define SECUREC_MIN_64BITS_NEG_VALUE 9223372036854775808ULL #define SECUREC_MAX_64BITS_POS_VALUE 9223372036854775807ULL #define SECUREC_MIN_32BITS_NEG_VALUE 2147483648ULL #define SECUREC_MAX_32BITS_POS_VALUE 2147483647ULL #define SECUREC_MAX_32BITS_VALUE 4294967295ULL #define SECUREC_MAX_32BITS_VALUE_INC 4294967296ULL #define SECUREC_MAX_32BITS_VALUE_DIV_TEN 429496729ULL #define SECUREC_LONG_BIT_NUM ((unsigned int)(sizeof(long) << 3U)) #define SECUREC_LONG_HEX_BEYOND_MAX(number) (((number) >> (SECUREC_LONG_BIT_NUM - 4U)) > 0) #define SECUREC_LONG_OCTAL_BEYOND_MAX(number) (((number) >> (SECUREC_LONG_BIT_NUM - 3U)) > 0) #define SECUREC_QWORD_HEX_BEYOND_MAX(number) (((number) >> (64U - 4U)) > 0) #define SECUREC_QWORD_OCTAL_BEYOND_MAX(number) (((number) >> (64U - 3U)) > 0) #define SECUREC_LP64_BIT_WIDTH 64 #define SECUREC_LP32_BIT_WIDTH 32 #endif #define SECUREC_CHAR(x) (x) #define SECUREC_BRACE '{' /* [ to { */ #ifdef SECUREC_FOR_WCHAR #define SECUREC_SCANF_BRACKET_CONDITION(comChr, ch, table, mask) ((comChr) == SECUREC_BRACE && \ (table) != NULL && \ (((table)[((unsigned int)(int)(ch) & SECUREC_CHAR_MASK) >> 3] ^ (mask)) & \ (1 << ((unsigned int)(int)(ch) & 7)))) #else #define SECUREC_SCANF_BRACKET_CONDITION(comChr, ch, table, mask) ((comChr) == SECUREC_BRACE && \ (((table)[((unsigned char)(ch) & 0xff) >> 3] ^ (mask)) & (1 << ((unsigned char)(ch) & 7)))) #endif #define SECUREC_SCANF_STRING_CONDITION(comChr, ch) ((comChr) == SECUREC_CHAR('s') && \ (!((ch) >= SECUREC_CHAR('\t') && (ch) <= SECUREC_CHAR('\r')) && (ch) != SECUREC_CHAR(' '))) /* Do not use |= optimize this code, it will cause compiling warning */ /* only supports wide characters with a maximum length of two bytes */ #define SECUREC_BRACKET_SET_BIT(table, ch) do { \ unsigned int tableIndex = SECUREC_BRACKET_INDEX(((unsigned int)(int)(ch) & SECUREC_CHAR_MASK)); \ unsigned int tableValue = SECUREC_BRACKET_VALUE(((unsigned int)(int)(ch) & SECUREC_CHAR_MASK)); \ (table)[tableIndex] = (unsigned char)((table)[tableIndex] | tableValue); \ } SECUREC_WHILE_ZERO #ifdef SECUREC_FOR_WCHAR /* table size is 32 x 256 */ #define SECUREC_BRACKET_TABLE_SIZE 8192 #define SECUREC_EOF WEOF #define SECUREC_MB_LEN 16 /* max. 
# bytes in multibyte char, see MB_LEN_MAX */ /* int to unsigned int clear e571 */ #define SECUREC_IS_DIGIT(chr) (!((unsigned int)(int)(chr) & 0xff00) && isdigit(((unsigned int)(int)(chr) & 0x00ff))) #define SECUREC_IS_XDIGIT(chr) (!((unsigned int)(int)(chr) & 0xff00) && isxdigit(((unsigned int)(int)(chr) & 0x00ff))) #define SECUREC_IS_SPACE(chr) iswspace((wint_t)(int)(chr)) #else #define SECUREC_BRACKET_TABLE_SIZE 32 #define SECUREC_EOF EOF #define SECUREC_IS_DIGIT(chr) isdigit((unsigned char)(chr) & 0x00ff) #define SECUREC_IS_XDIGIT(chr) isxdigit((unsigned char)(chr) & 0x00ff) #define SECUREC_IS_SPACE(chr) isspace((unsigned char)(chr) & 0x00ff) #endif static SecInt SecSkipSpaceChar(SecFileStream *stream, int *counter); static SecInt SecGetChar(SecFileStream *stream, int *counter); static void SecUnGetChar(SecInt ch, SecFileStream *stream, int *counter); typedef struct { #ifdef SECUREC_FOR_WCHAR unsigned char *table; /* default NULL */ #else unsigned char table[SECUREC_BRACKET_TABLE_SIZE]; /* Array length is large enough in application scenarios */ #endif unsigned char mask; /* default 0 */ } SecBracketTable; #ifdef SECUREC_FOR_WCHAR #define SECUREC_INIT_BRACKET_TABLE { NULL, 0 } #else #define SECUREC_INIT_BRACKET_TABLE { { 0 }, 0 } #endif #if SECUREC_ENABLE_SCANF_FLOAT typedef struct { size_t floatStrSize; /* initialization must be the length of the buffer in characters */ size_t floatStrUsedLen; /* store float string len */ SecChar buffer[SECUREC_FLOAT_BUFSIZE + 1]; SecChar *floatStr; /* Initialization must point to buffer */ SecChar *allocatedFloatStr; /* Initialization must be NULL to store the allocated pointer */ } SecFloatSpec; #endif typedef struct { SecUnsignedInt64 number64; unsigned long number; int numberWidth; /* 0 = SHORT, 1 = int, > 1 long or L_DOUBLE */ int isInt64Arg; /* 1 for 64-bit integer, 0 otherwise */ int negative; /* 0 is positive */ #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) int beyondMax; /* Non-zero means the value is beyond the maximum */ #endif void *argPtr; /* Variable parameter pointer */ size_t arrayWidth; /* length of the variable parameter buffer, in characters */ int width; /* width number in format */ int widthSet; /* 0 means width not set in format */ int comChr; /* Lowercase format conversion characters */ int oriComChr; /* store number conversion */ signed char isWChar; /* -1/0 not wchar, 1 for wchar */ char suppress; /* 0 means no %* in format */ } SecScanSpec; #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) #define SECUREC_INIT_NUMBER_SPEC { 0, 0, 0, 0, 0, 0, NULL, 0, 0, 0, 0, 0, 0 } #else #define SECUREC_INIT_NUMBER_SPEC { 0, 0, 0, 0, 0, 0, NULL, 0, 0, 0, 0, 0 } #endif #ifdef SECUREC_FOR_WCHAR #define SECUREC_GETC fgetwc #define SECUREC_UN_GETC ungetwc #define SECUREC_CHAR_MASK 0xffff #else #define SECUREC_GETC fgetc #define SECUREC_UN_GETC ungetc #define SECUREC_CHAR_MASK 0xff #endif /* * Determine whether pointers are 64-bit * return 0 if not, 1 if a 64-bit pointer */ static int SecIs64BitPtr(size_t sizeOfVoidStar) { /* pointer size is 4 or 8; on a 64-bit system the value is not 0 */ /* to clear e778 */ if ((sizeOfVoidStar & sizeof(SecInt64)) != 0) { return 1; } return 0; } #if SECUREC_ENABLE_SCANF_FLOAT /* * Convert a floating point string to a floating point number */ static void SecAssignFloat(const char *floatStr, int numberWidth, void *argPtr) { char *endPtr = NULL; double d; #if SECUREC_SUPPORT_STRTOLD if (numberWidth == SECUREC_NUM_WIDTH_LONG_LONG) { long double d2 = strtold(floatStr, &endPtr); *(long double UNALIGNED
*)(argPtr) = d2; return; } #endif d = strtod(floatStr, &endPtr); if (numberWidth > SECUREC_NUM_WIDTH_INT) { *(double UNALIGNED *)(argPtr) = (double)d; } else { *(float UNALIGNED *)(argPtr) = (float)d; } } #ifdef SECUREC_FOR_WCHAR /* * Convert a floating point wchar string to a floating point number * returns 0 on success */ static int SecAssignFloatW(const SecFloatSpec *floatSpec, const SecScanSpec *spec) { /* convert float string */ size_t mbsLen; size_t tempFloatStrLen = (size_t)(floatSpec->floatStrSize + 1) * sizeof(wchar_t); char *tempFloatStr = (char *)SECUREC_MALLOC(tempFloatStrLen); if (tempFloatStr == NULL) { return -1; } tempFloatStr[0] = '\0'; SECUREC_MASK_MSVC_CRT_WARNING mbsLen = wcstombs(tempFloatStr, floatSpec->floatStr, tempFloatStrLen - 1); SECUREC_END_MASK_MSVC_CRT_WARNING if (mbsLen != (size_t)-1) { tempFloatStr[mbsLen] = '\0'; SecAssignFloat(tempFloatStr, spec->numberWidth, spec->argPtr); } else { SECUREC_FREE(tempFloatStr); return -1; } SECUREC_FREE(tempFloatStr); return 0; } #endif /* * Append a character to the floating point string * return 0 OK */ static int SecUpdateFloatString(SecChar ch, SecFloatSpec *floatSpec) { floatSpec->floatStr[floatSpec->floatStrUsedLen++] = ch; /* ch must be '0' - '9' */ if (floatSpec->floatStrUsedLen < floatSpec->floatStrSize) { return 0; } if (floatSpec->allocatedFloatStr == NULL) { /* add 1 to clear ZERO LENGTH ALLOCATIONS warning */ size_t oriBufSize = floatSpec->floatStrSize * (SECUREC_BUF_EXT_MUL * sizeof(SecChar)) + 1; void *tmpPointer = (void *)SECUREC_MALLOC(oriBufSize); if (tmpPointer == NULL) { return -1; } if (memcpy_s(tmpPointer, oriBufSize, floatSpec->floatStr, floatSpec->floatStrSize * sizeof(SecChar)) != EOK) { SECUREC_FREE(tmpPointer); /* This is dead code, just to meet the coding requirements */ return -1; } floatSpec->floatStr = (SecChar *) (tmpPointer); floatSpec->allocatedFloatStr = (SecChar *) (tmpPointer); /* use to clear free on stack warning */ floatSpec->floatStrSize *= SECUREC_BUF_EXT_MUL; /* this is OK, oriBufSize plus 1 just clears the warning */ return 0; } else { /* LSD 2014.3.6 fix, replace realloc with malloc to avoid heap injection */ size_t oriBufSize = floatSpec->floatStrSize * sizeof(SecChar); size_t nextSize = (oriBufSize * SECUREC_BUF_EXT_MUL) + 1; /* add 1 to clear static check tool warning */ /* Prevents integer overflow when calculating the wide character length.
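 * Editor's note (added): SecUpdateFloatString grows the staging buffer
 * geometrically, multiplying floatStrSize by SECUREC_BUF_EXT_MUL (2) on each
 * expansion, and intentionally uses SECUREC_MALLOC + memcpy_s + memset_s
 * rather than realloc, so the old block is wiped before it is freed (the
 * "heap injection" note above).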
* The maximum value SECUREC_MAX_WIDTH_LEN is enough */ if (nextSize <= SECUREC_MAX_WIDTH_LEN) { void *tmpPointer = (void *)SECUREC_MALLOC(nextSize); if (tmpPointer == NULL) { return -1; } if (memcpy_s(tmpPointer, nextSize, floatSpec->floatStr, oriBufSize) != EOK) { SECUREC_FREE(tmpPointer); /* This is dead code, just to meet the coding requirements */ return -1; } if (memset_s(floatSpec->floatStr, oriBufSize, 0, oriBufSize) != EOK) { SECUREC_FREE(tmpPointer); /* This is dead code, just to meet the coding requirements */ return -1; } SECUREC_FREE(floatSpec->floatStr); floatSpec->floatStr = (SecChar *) (tmpPointer); floatSpec->allocatedFloatStr = (SecChar *) (tmpPointer); /* use to clear free on stack warning */ floatSpec->floatStrSize *= SECUREC_BUF_EXT_MUL; /* this is OK, oriBufSize plus 1 just clears the warning */ return 0; } } return -1; } #endif #ifndef SECUREC_FOR_WCHAR /* LSD only multibyte strings need the isleadbyte() function */ static int SecIsLeadByte(SecInt ch) { unsigned int c = (unsigned int)ch; #if !(defined(_MSC_VER) || defined(_INC_WCTYPE)) return (int)(c & 0x80); #else return (int)isleadbyte((int)(c & 0xff)); #endif } #endif /* * Parse whether the conversion is a wide character */ static void SecUpdateWcharFlagByType(SecUnsignedChar ch, SecScanSpec *spec) { #if defined(SECUREC_FOR_WCHAR) && (defined(SECUREC_COMPATIBLE_WIN_FORMAT)) signed char flagForUpperType = -1; signed char flagForLowerType = 1; #else signed char flagForUpperType = 1; signed char flagForLowerType = -1; #endif /* if no l or h flag */ if (spec->isWChar == 0) { if ((ch == SECUREC_CHAR('C')) || (ch == SECUREC_CHAR('S'))) { spec->isWChar = flagForUpperType; } else { spec->isWChar = flagForLowerType; } } return; } /* * decode %l %ll */ static void SecDecodeScanQualifierL(const SecUnsignedChar **format, SecScanSpec *spec) { const SecUnsignedChar *fmt = *format; if (*(fmt + 1) == SECUREC_CHAR('l')) { spec->isInt64Arg = 1; spec->numberWidth = SECUREC_NUM_WIDTH_LONG_LONG; ++fmt; } else { spec->numberWidth = SECUREC_NUM_WIDTH_LONG; #if defined(SECUREC_ON_64BITS) && !(defined(SECUREC_COMPATIBLE_WIN_FORMAT)) /* on 64-bit Windows systems sizeof(long) is 32-bit, hence the exclusion above */ spec->isInt64Arg = 1; #endif spec->isWChar = 1; } *format = fmt; } /* * decode %I %I32 %I64 %Id %Ii %Io ...
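 * Editor's note (illustrative, added): these are Windows-style fixed-width
 * prefixes. For example, a hypothetical call
 *     sscanf_s(buf, "%I64d", &value64);
 * converts into a 64-bit integer, %I32d forces 32 bits, and a bare
 * %Id/%Ii/%Io/%Ix derives its width from sizeof(void *) via SecIs64BitPtr.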
* set finishFlag to 1 finish Flag */ static void SecDecodeScanQualifierI(const SecUnsignedChar **format, SecScanSpec *spec, int *finishFlag) { const SecUnsignedChar *fmt = *format; if ((*(fmt + 1) == SECUREC_CHAR('6')) && (*(fmt + 2) == SECUREC_CHAR('4'))) { /* offset 2 for I64 */ spec->isInt64Arg = 1; *format = *format + 2; /* add 2 to skip I64 point to '4' next loop will inc */ } else if ((*(fmt + 1) == SECUREC_CHAR('3')) && (*(fmt + 2) == SECUREC_CHAR('2'))) { /* offset 2 for I32 */ *format = *format + 2; /* add 2 to skip I32 point to '2' next loop will inc */ } else if ((*(fmt + 1) == SECUREC_CHAR('d')) || (*(fmt + 1) == SECUREC_CHAR('i')) || (*(fmt + 1) == SECUREC_CHAR('o')) || (*(fmt + 1) == SECUREC_CHAR('x')) || (*(fmt + 1) == SECUREC_CHAR('X'))) { spec->isInt64Arg = SecIs64BitPtr(sizeof(void *)); } else { /* for %I */ spec->isInt64Arg = SecIs64BitPtr(sizeof(void *)); *finishFlag = 1; } } static int SecDecodeScanWidth(const SecUnsignedChar **format, SecScanSpec *spec) { const SecUnsignedChar *fmt = *format; while (SECUREC_IS_DIGIT(*fmt)) { spec->widthSet = 1; if (SECUREC_MUL_TEN_ADD_BEYOND_MAX(spec->width)) { return -1; } spec->width = (int)SECUREC_MUL_TEN((unsigned int)spec->width) + (unsigned char)(*fmt - SECUREC_CHAR('0')); ++fmt; } *format = fmt; return 0; } /* * init default flags for each format */ static void SecSetDefaultScanSpec(SecScanSpec *spec) { spec->number64 = 0; spec->number = 0; spec->numberWidth = SECUREC_NUM_WIDTH_INT; /* 0 = SHORT, 1 = int, > 1 long or L_DOUBLE */ spec->isInt64Arg = 0; /* 1 for 64-bit integer, 0 otherwise */ spec->negative = 0; #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) spec->beyondMax = 0; #endif spec->argPtr = NULL; spec->arrayWidth = 0; spec->width = 0; spec->widthSet = 0; spec->comChr = 0; spec->isWChar = 0; spec->suppress = 0; } /* * decode qualifier %I %L %h ... 
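 * Editor's note (summary added for clarity; mirrors the switch below):
 *     h           -> one step narrower (short; hh reaches char)
 *     l           -> long (and wide char for %lc/%ls); ll, L, q -> 64-bit
 *     j           -> intmax_t width; z, t -> size_t/ptrdiff_t width
 *     w           -> wide character; * -> assignment suppression
 *     I, I32, I64 -> handled by SecDecodeScanQualifierI above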
* set finishFlag to 1 finish Flag */ static void SecDecodeScanQualifier(const SecUnsignedChar **format, SecScanSpec *spec, int *finishFlag) { switch ((int)(unsigned char)(**(format))) { case SECUREC_CHAR('F'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('N'): break; case SECUREC_CHAR('h'): --spec->numberWidth; /* h for SHORT , hh for CHAR */ spec->isWChar = -1; break; #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT case SECUREC_CHAR('j'): spec->numberWidth = SECUREC_NUM_WIDTH_LONG_LONG; /* intmax_t or uintmax_t */ spec->isInt64Arg = 1; break; case SECUREC_CHAR('t'): /* fall-through */ /* FALLTHRU */ #endif case SECUREC_CHAR('z'): #ifdef SECUREC_ON_64BITS spec->numberWidth = SECUREC_NUM_WIDTH_LONG_LONG; spec->isInt64Arg = 1; #else spec->numberWidth = SECUREC_NUM_WIDTH_LONG; #endif break; case SECUREC_CHAR('L'): /* long double */ /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('q'): spec->numberWidth = SECUREC_NUM_WIDTH_LONG_LONG; spec->isInt64Arg = 1; break; case SECUREC_CHAR('l'): SecDecodeScanQualifierL(format, spec); break; case SECUREC_CHAR('w'): spec->isWChar = 1; break; case SECUREC_CHAR('*'): spec->suppress = 1; break; case SECUREC_CHAR('I'): SecDecodeScanQualifierI(format, spec, finishFlag); break; default: *finishFlag = 1; break; } } /* * decode width and qualifier in format */ static int SecDecodeScanFlag(const SecUnsignedChar **format, SecScanSpec *spec) { const SecUnsignedChar *fmt = *format; int finishFlag = 0; do { ++fmt; /* first skip % , next seek fmt */ /* may %*6d , so put it inside the loop */ if (SecDecodeScanWidth(&fmt, spec) != 0) { return -1; } SecDecodeScanQualifier(&fmt, spec, &finishFlag); } while (finishFlag == 0); *format = fmt; return 0; } /* * Judging whether a zeroing buffer is needed according to different formats */ static int SecDecodeClearFormat(const SecUnsignedChar *format, int *comChr) { const SecUnsignedChar *fmt = format; /* to lowercase */ int ch = (unsigned char)(*fmt) | (SECUREC_CHAR('a') - SECUREC_CHAR('A')); if (!(ch == SECUREC_CHAR('c') || ch == SECUREC_CHAR('s') || ch == SECUREC_BRACE)) { return -1; /* first argument is not a string type */ } if (ch == SECUREC_BRACE) { #if !(defined(SECUREC_COMPATIBLE_WIN_FORMAT)) if (*fmt == SECUREC_CHAR('{')) { return -1; } #endif ++fmt; if (*fmt == SECUREC_CHAR('^')) { ++fmt; } if (*fmt == SECUREC_CHAR(']')) { ++fmt; } while ((*fmt != SECUREC_CHAR('\0')) && (*fmt != SECUREC_CHAR(']'))) { ++fmt; } if (*fmt == SECUREC_CHAR('\0')) { return -1; /* trunc'd format string */ } } *comChr = ch; return 0; } /* * add L'\0' for wchar string , add '\0' for char string */ static void SecAddEndingZero(void *ptr, const SecScanSpec *spec) { *(char *)ptr = '\0'; (void)spec; /* clear not use */ #if SECUREC_HAVE_WCHART if (spec->isWChar > 0) { *(wchar_t UNALIGNED *)ptr = L'\0'; } #endif } #ifdef SECUREC_FOR_WCHAR /* * Clean up the first %s %c buffer to zero for wchar version */ void SecClearDestBufW(const wchar_t *buffer, const wchar_t *format, va_list argList) #else /* * Clean up the first %s %c buffer to zero for char version */ void SecClearDestBuf(const char *buffer, const char *format, va_list argList) #endif { va_list argListSave; /* backup for argList value, this variable don't need initialized */ SecScanSpec spec; int comChr = 0; const SecUnsignedChar *fmt = (const SecUnsignedChar *)format; if (fmt == NULL) { return; } /* find first % */ while (*fmt != SECUREC_CHAR('\0') && *fmt != SECUREC_CHAR('%')) { ++fmt; } if (*fmt == SECUREC_CHAR('\0')) { return; } SecSetDefaultScanSpec(&spec); if (SecDecodeScanFlag(&fmt, 
&spec) != 0) { return; } /* update wchar flag for %S %C */ SecUpdateWcharFlagByType(*fmt, &spec); if (spec.suppress != 0 || SecDecodeClearFormat(fmt, &comChr) != 0) { return; } if ((buffer != NULL) && (*buffer != SECUREC_CHAR('\0')) && (comChr != SECUREC_CHAR('s'))) { /* when buffer not empty just clear %s. * example: calling sscanf with arguments (" \n", "%s", s, sizeof(s)) */ return; } (void)memset(&argListSave, 0, sizeof(va_list)); /* to clear e530 argListSave not initialized */ #if defined(va_copy) va_copy(argListSave, argList); #elif defined(__va_copy) /* for vxworks */ __va_copy(argListSave, argList); #else argListSave = argList; #endif do { void *argPtr = (void *)va_arg(argListSave, void *); /* Get the next argument - size of the array in characters */ size_t arrayWidth = ((size_t)(va_arg(argListSave, size_t))) & 0xFFFFFFFFUL; va_end(argListSave); /* to clear e438 last value assigned not used, the compiler will optimize this code */ (void)argListSave; /* There is no need to judge the upper limit */ if (arrayWidth == 0 || argPtr == NULL) { return; } /* clear one char */ SecAddEndingZero(argPtr, &spec); } SECUREC_WHILE_ZERO; return; } /* * Assign number to output buffer */ static void SecAssignNumber(const SecScanSpec *spec) { void *argPtr = spec->argPtr; if (spec->isInt64Arg != 0) { #if defined(SECUREC_VXWORKS_PLATFORM) #if defined(SECUREC_VXWORKS_PLATFORM_COMP) *(SecInt64 UNALIGNED *)argPtr = (SecInt64)(spec->number64); #else /* take number64 as unsigned; the unsigned-to-int cast clears a compile warning */ *(SecInt64 UNALIGNED *)argPtr = *(SecUnsignedInt64 *)(&(spec->number64)); #endif #else /* take number64 as unsigned number */ *(SecInt64 UNALIGNED *)argPtr = (SecInt64)(spec->number64); #endif return; } if (spec->numberWidth > SECUREC_NUM_WIDTH_INT) { /* take number as unsigned number */ *(long UNALIGNED *)argPtr = (long)(spec->number); } else if (spec->numberWidth == SECUREC_NUM_WIDTH_INT) { *(int UNALIGNED *)argPtr = (int)(spec->number); } else if (spec->numberWidth == SECUREC_NUM_WIDTH_SHORT) { /* take number as unsigned number */ *(short UNALIGNED *)argPtr = (short)(spec->number); } else { /* < 0 for hh format modifier */ /* take number as unsigned number */ *(char UNALIGNED *)argPtr = (char)(spec->number); } } #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) /* * Judge the long bit width */ static int SecIsLongBitEqual(int bitNum) { return (unsigned int)bitNum == SECUREC_LONG_BIT_NUM; } #endif /* * Convert hexadecimal characters to decimal value */ static int SecHexValueOfChar(SecInt ch) { /* using isdigit causes tool false alarms */ return (int)((ch >= '0' && ch <= '9') ? ((unsigned char)ch - '0') : ((((unsigned char)ch | (unsigned char)('a' - 'A')) - ('a')) + 10)); /* adding 10 yields the hex digit value */ } /* * Parse decimal character to integer for 32-bit.
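 * Editor's note (worked example, added): with a 32-bit unsigned long the edge
 * is SECUREC_MAX_32BITS_VALUE_DIV_TEN = 429496729; once the accumulated value
 * exceeds it, appending one more digit must overflow 4294967295, so beyondMax
 * is recorded and the result is saturated later by SecFinishNumber.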
*/ static void SecDecodeNumberDecimal(SecInt ch, SecScanSpec *spec) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) unsigned long decimalEdge = SECUREC_MAX_32BITS_VALUE_DIV_TEN; #ifdef SECUREC_ON_64BITS if (SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { decimalEdge = (unsigned long)SECUREC_MAX_64BITS_VALUE_DIV_TEN; } #else if (SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { decimalEdge = SECUREC_MAX_32BITS_VALUE_DIV_TEN; } #endif if (spec->number > decimalEdge) { spec->beyondMax = 1; } #endif spec->number = SECUREC_MUL_TEN(spec->number); #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (spec->number == SECUREC_MUL_TEN(decimalEdge)) { SecUnsignedInt64 number64As = (unsigned long)SECUREC_MAX_64BITS_VALUE - spec->number; if (number64As < (SecUnsignedInt64)((SecUnsignedInt)ch - SECUREC_CHAR('0'))) { spec->beyondMax = 1; } } #endif spec->number += (unsigned long)((SecUnsignedInt)ch - SECUREC_CHAR('0')); } /* * Parse Hex character to integer for 32bit . */ static void SecDecodeNumberHex(SecInt ch, SecScanSpec *spec) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (SECUREC_LONG_HEX_BEYOND_MAX(spec->number)) { spec->beyondMax = 1; } #endif spec->number = SECUREC_MUL_SIXTEEN(spec->number); spec->number += (unsigned long)(unsigned int)SecHexValueOfChar(ch); } /* * Parse Octal character to integer for 32bit . */ static void SecDecodeNumberOctal(SecInt ch, SecScanSpec *spec) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (SECUREC_LONG_OCTAL_BEYOND_MAX(spec->number)) { spec->beyondMax = 1; } #endif spec->number = SECUREC_MUL_EIGHT(spec->number); spec->number += (unsigned long)((SecUnsignedInt)ch - SECUREC_CHAR('0')); } #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) /* Compatible with integer negative values other than int */ static void SecFinishNumberNegativeOther(int comChr, int numberWidth, SecScanSpec *spec) { if ((comChr == SECUREC_CHAR('d')) || (comChr == SECUREC_CHAR('i'))) { if (spec->number > (unsigned long)(1ULL << (SECUREC_LONG_BIT_NUM - 1))) { spec->number = (unsigned long)(1ULL << (SECUREC_LONG_BIT_NUM - 1)); } else { spec->number = (unsigned long)(-(long)spec->number); } if (spec->beyondMax != 0) { if (numberWidth < SECUREC_NUM_WIDTH_INT) { spec->number = 0; } else if (numberWidth == SECUREC_NUM_WIDTH_LONG) { spec->number = ((unsigned long)(1UL << (SECUREC_LONG_BIT_NUM - 1))); } } } else { /* o, u, x, X, p */ spec->number = (unsigned long)(-(long)spec->number); if (spec->beyondMax != 0) { spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; } } } /* Compatible processing of integer negative numbers */ static void SecFinishNumberNegativeInt(int comChr, SecScanSpec *spec) { if ((comChr == SECUREC_CHAR('d')) || (comChr == SECUREC_CHAR('i'))) { #ifdef SECUREC_ON_64BITS if (SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { if ((spec->number > SECUREC_MIN_64BITS_NEG_VALUE)) { spec->number = 0; } else { spec->number = (unsigned int)(-(int)spec->number); } } #else if (SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { if ((spec->number > SECUREC_MIN_32BITS_NEG_VALUE)) { spec->number = SECUREC_MIN_32BITS_NEG_VALUE; } else { spec->number = (unsigned int)(-(int)spec->number); } } #endif if (spec->beyondMax != 0) { #ifdef SECUREC_ON_64BITS if (SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { spec->number = 0; } #else if (SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { spec->number = SECUREC_MIN_32BITS_NEG_VALUE; } #endif } } else { /* o, u, x, X ,p */ 
#ifdef SECUREC_ON_64BITS if (spec->number > SECUREC_MAX_32BITS_VALUE_INC) { spec->number = SECUREC_MAX_32BITS_VALUE; } else { spec->number = (unsigned int)(-(int)spec->number); } #else spec->number = (unsigned int)(-(int)spec->number); #endif if (spec->beyondMax != 0) { spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; } } } /* Compatible with integer positive values other than int */ static void SecFinishNumberPositiveOther(int comChr, int numberWidth, SecScanSpec *spec) { if (comChr == SECUREC_CHAR('d') || comChr == SECUREC_CHAR('i')) { if (spec->number > ((unsigned long)(1UL << (SECUREC_LONG_BIT_NUM - 1)) - 1)) { spec->number = ((unsigned long)(1UL << (SECUREC_LONG_BIT_NUM - 1)) - 1); } if ((spec->beyondMax != 0 && numberWidth < SECUREC_NUM_WIDTH_INT)) { spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; } if (spec->beyondMax != 0 && numberWidth == SECUREC_NUM_WIDTH_LONG) { spec->number = ((unsigned long)(1UL << (SECUREC_LONG_BIT_NUM - 1)) - 1); } } else { if (spec->beyondMax != 0) { spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; } } } /* Compatible processing of integer positive numbers */ static void SecFinishNumberPositiveInt(int comChr, SecScanSpec *spec) { if ((comChr == SECUREC_CHAR('d')) || (comChr == SECUREC_CHAR('i'))) { #ifdef SECUREC_ON_64BITS if (SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { if (spec->number > SECUREC_MAX_64BITS_POS_VALUE) { spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; } } if (spec->beyondMax != 0 && SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; } #else if (SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { if (spec->number > SECUREC_MAX_32BITS_POS_VALUE) { spec->number = SECUREC_MAX_32BITS_POS_VALUE; } } if (spec->beyondMax != 0 && SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { spec->number = SECUREC_MAX_32BITS_POS_VALUE; } #endif } else { /* o,u,x,X,p */ if (spec->beyondMax != 0) { spec->number = SECUREC_MAX_32BITS_VALUE; } } } #endif /* * Parse decimal character to integer for 64bit . */ static void SecDecodeNumber64Decimal(SecInt ch, SecScanSpec *spec) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (spec->number64 > SECUREC_MAX_64BITS_VALUE_DIV_TEN) { spec->beyondMax = 1; } #endif spec->number64 = SECUREC_MUL_TEN(spec->number64); #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (spec->number64 == SECUREC_MAX_64BITS_VALUE_CUT_LAST_DIGIT) { SecUnsignedInt64 number64As = (SecUnsignedInt64)SECUREC_MAX_64BITS_VALUE - spec->number64; if (number64As < (SecUnsignedInt64)((SecUnsignedInt)ch - SECUREC_CHAR('0'))) { spec->beyondMax = 1; } } #endif spec->number64 += (SecUnsignedInt64)((SecUnsignedInt)ch - SECUREC_CHAR('0')); } /* * Parse Hex character to integer for 64bit . */ static void SecDecodeNumber64Hex(SecInt ch, SecScanSpec *spec) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (SECUREC_QWORD_HEX_BEYOND_MAX(spec->number64)) { spec->beyondMax = 1; } #endif spec->number64 = SECUREC_MUL_SIXTEEN(spec->number64); spec->number64 += (SecUnsignedInt64)(unsigned int)SecHexValueOfChar(ch); } /* * Parse Octal character to integer for 64bit . 
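 * Editor's note (added): each radix has a 32-bit and a 64-bit decoder; the
 * pairs are collected in the g_secDecodeNumber{Decimal,Hex,Octal} function
 * pointer tables defined below and dispatched by spec->isInt64Arg, e.g.
 *     (*g_secDecodeNumberDecimal[spec->isInt64Arg])(ch, spec);
 * which keeps the per-character scan loop free of width branches.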
*/ static void SecDecodeNumber64Octal(SecInt ch, SecScanSpec *spec) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (SECUREC_QWORD_OCTAL_BEYOND_MAX(spec->number64)) { spec->beyondMax = 1; } #endif spec->number64 = SECUREC_MUL_EIGHT(spec->number64); spec->number64 += (SecUnsignedInt64)((SecUnsignedInt)ch - SECUREC_CHAR('0')); } #define SECUREC_DECODE_NUMBER_FUNC_NUM 2 /* Function name cannot add address symbol, causing 546 alarm */ static void (*g_secDecodeNumberHex[SECUREC_DECODE_NUMBER_FUNC_NUM])(SecInt ch, SecScanSpec *spec) = \ { SecDecodeNumberHex, SecDecodeNumber64Hex }; static void (*g_secDecodeNumberOctal[SECUREC_DECODE_NUMBER_FUNC_NUM])(SecInt ch, SecScanSpec *spec) = \ { SecDecodeNumberOctal, SecDecodeNumber64Octal }; static void (*g_secDecodeNumberDecimal[SECUREC_DECODE_NUMBER_FUNC_NUM])(SecInt ch, SecScanSpec *spec) = \ { SecDecodeNumberDecimal, SecDecodeNumber64Decimal }; /* * Parse 64-bit integer formatted input, return 0 when ch is a number. */ static int SecDecodeNumber(SecInt ch, SecScanSpec *spec) { if (spec->comChr == SECUREC_CHAR('x') || spec->comChr == SECUREC_CHAR('p')) { if (SECUREC_IS_XDIGIT(ch)) { (*g_secDecodeNumberHex[spec->isInt64Arg])(ch, spec); } else { return -1; } return 0; } if (!(SECUREC_IS_DIGIT(ch))) { return -1; } if (spec->comChr == SECUREC_CHAR('o')) { if (ch < SECUREC_CHAR('8')) { (*g_secDecodeNumberOctal[spec->isInt64Arg])(ch, spec); } else { return -1; } } else { /* comChr is 'd' */ (*g_secDecodeNumberDecimal[spec->isInt64Arg])(ch, spec); } return 0; } /* * Complete the final 32-bit integer formatted input */ static void SecFinishNumber(SecScanSpec *spec) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (spec->negative != 0) { if (spec->numberWidth == SECUREC_NUM_WIDTH_INT) { SecFinishNumberNegativeInt(spec->oriComChr, spec); } else { SecFinishNumberNegativeOther(spec->oriComChr, spec->numberWidth, spec); } } else { if (spec->numberWidth == SECUREC_NUM_WIDTH_INT) { SecFinishNumberPositiveInt(spec->oriComChr, spec); } else { SecFinishNumberPositiveOther(spec->oriComChr, spec->numberWidth, spec); } } #else if (spec->negative != 0) { #if defined(__hpux) if (spec->oriComChr != SECUREC_CHAR('p')) { spec->number = (unsigned long)(-(long)spec->number); } #else spec->number = (unsigned long)(-(long)spec->number); #endif } #endif return; } /* * Complete the final 64-bit integer formatted input */ static void SecFinishNumber64(SecScanSpec *spec) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) if (spec->negative != 0) { if (spec->oriComChr == (SECUREC_CHAR('d')) || (spec->oriComChr == SECUREC_CHAR('i'))) { if (spec->number64 > SECUREC_MIN_64BITS_NEG_VALUE) { spec->number64 = SECUREC_MIN_64BITS_NEG_VALUE; } else { spec->number64 = (SecUnsignedInt64)(-(SecInt64)spec->number64); } if (spec->beyondMax != 0) { spec->number64 = SECUREC_MIN_64BITS_NEG_VALUE; } } else { /* o, u, x, X, p */ spec->number64 = (SecUnsignedInt64)(-(SecInt64)spec->number64); if (spec->beyondMax != 0) { spec->number64 = SECUREC_MAX_64BITS_VALUE; } } } else { if ((spec->oriComChr == SECUREC_CHAR('d')) || (spec->oriComChr == SECUREC_CHAR('i'))) { if (spec->number64 > SECUREC_MAX_64BITS_POS_VALUE) { spec->number64 = SECUREC_MAX_64BITS_POS_VALUE; } if (spec->beyondMax != 0) { spec->number64 = SECUREC_MAX_64BITS_POS_VALUE; } } else { if (spec->beyondMax != 0) { spec->number64 = SECUREC_MAX_64BITS_VALUE; } } } #else if (spec->negative != 0) { #if defined(__hpux) if (spec->oriComChr != 
SECUREC_CHAR('p')) { spec->number64 = (SecUnsignedInt64)(-(SecInt64)spec->number64); } #else spec->number64 = (SecUnsignedInt64)(-(SecInt64)spec->number64); #endif } #endif return; } static void (*g_secFinishNumber[SECUREC_DECODE_NUMBER_FUNC_NUM])(SecScanSpec *spec) = \ { SecFinishNumber, SecFinishNumber64 }; #if SECUREC_ENABLE_SCANF_FILE /* * Adjust the pointer position of the file stream */ static void SecSeekStream(SecFileStream *stream) { if ((stream->count == 0) && feof(stream->pf)) { /* file pointer at the end of file, don't need to seek back */ stream->base[0] = '\0'; return; } /* LSD seek to original position, bug fix 2014 1 21 */ if (fseek(stream->pf, stream->oriFilePos, SEEK_SET)) { /* seek failed, ignore it */ stream->oriFilePos = 0; return; } if (stream->fileRealRead > 0) { /* LSD bug fix. when file reach to EOF, don't seek back */ #if (defined(SECUREC_COMPATIBLE_WIN_FORMAT)) int loops; for (loops = 0; loops < (stream->fileRealRead / SECUREC_BUFFERED_BLOK_SIZE); ++loops) { if (fread(stream->base, (size_t)1, (size_t)SECUREC_BUFFERED_BLOK_SIZE, stream->pf) != SECUREC_BUFFERED_BLOK_SIZE) { break; } } if ((stream->fileRealRead % SECUREC_BUFFERED_BLOK_SIZE) != 0) { size_t ret = fread(stream->base, (size_t)((unsigned int)stream->fileRealRead % SECUREC_BUFFERED_BLOK_SIZE), (size_t)1, stream->pf); if ((ret == 1 || ret == 0) && (ftell(stream->pf) < stream->oriFilePos + stream->fileRealRead)) { (void)fseek(stream->pf, stream->oriFilePos + stream->fileRealRead, SEEK_SET); } } #else /* in linux like system */ if (fseek(stream->pf, stream->oriFilePos + stream->fileRealRead, SEEK_SET)) { /* seek failed, ignore it */ stream->oriFilePos = 0; } #endif } return; } /* * Adjust the pointer position of the file stream and free memory */ static void SecAdjustStream(SecFileStream *stream) { if (stream != NULL && (stream->flag & SECUREC_FILE_STREAM_FLAG) && stream->base != NULL) { SecSeekStream(stream); SECUREC_FREE(stream->base); stream->base = NULL; } return; } #endif static void SecSkipSpaceFormat(const SecUnsignedChar **format) { const SecUnsignedChar *fmt = *format; while (SECUREC_IS_SPACE(*fmt)) { ++fmt; } *format = fmt; } #ifndef SECUREC_FOR_WCHAR /* * Handling multi-character characters */ static int SecDecodeLeadByte(SecInt ch, const SecUnsignedChar **format, SecFileStream *stream, int *counter) { #if SECUREC_HAVE_MBTOWC char temp[SECUREC_MULTI_BYTE_MAX_LEN]; const SecUnsignedChar *fmt = *format; wchar_t tempWChar = L'\0'; int ch2 = SecGetChar(stream, counter); if (*fmt == SECUREC_CHAR('\0') || (int)(*fmt) != (ch2)) { /* LSD in console mode, ungetc twice may cause problem */ SecUnGetChar(ch2, stream, counter); SecUnGetChar(ch, stream, counter); return -1; } ++fmt; if (MB_CUR_MAX >= SECUREC_UTF8_BOM_HEADER_SIZE && (((unsigned char)ch & SECUREC_UTF8_LEAD_1ST) == SECUREC_UTF8_LEAD_1ST) && (((unsigned char)ch2 & SECUREC_UTF8_LEAD_2ND) == SECUREC_UTF8_LEAD_2ND)) { /* this char is very likely to be a UTF-8 char */ int ch3 = SecGetChar(stream, counter); temp[0] = (char)ch; temp[1] = (char)ch2; /* 1 index of second character */ temp[2] = (char)ch3; /* 2 index of third character */ temp[3] = '\0'; /* 3 of string terminator position */ if (mbtowc(&tempWChar, temp, sizeof(temp)) > 0) { /* succeed */ if (*fmt == SECUREC_CHAR('\0') || (int)(*fmt) != (int)ch3) { SecUnGetChar(ch3, stream, counter); return -1; } ++fmt; *counter = *counter - 1; } else { SecUnGetChar(ch3, stream, counter); } } *counter = *counter - 1; /* only count as one character read */ *format = fmt; return 0; #else SecUnGetChar(ch, 
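/*
 * Editor's note (added; illustrative summary of the scanset parser that follows):
 * a %[...] set is stored as a bitmap with one bit per byte value, using
 * SECUREC_BRACKET_INDEX(ch) == ch >> 3 to pick the byte and
 * SECUREC_BRACKET_VALUE(ch) == 1 << (ch & 7) to pick the bit; for a negated
 * set %[^...] the mask is 0xff, and the match test XORs the mask so the same
 * lookup covers both polarities.
 */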
stream, counter); (void)format; return -1; #endif } #endif /* * Resolving sequence of characters from %[ format */ static int SecSetupBracketTable(const SecUnsignedChar **format, SecBracketTable *bracketTable) { const SecUnsignedChar *fmt = *format; SecUnsignedChar prevChar = 0; SecUnsignedChar expCh; SecUnsignedChar last = 0; #if !(defined(SECUREC_COMPATIBLE_WIN_FORMAT)) if (*fmt == SECUREC_CHAR('{')) { return -1; } #endif /* for building "table" data */ ++fmt; /* skip [ */ bracketTable->mask = 0; if (*fmt == SECUREC_CHAR('^')) { ++fmt; bracketTable->mask = (unsigned char)0xff; } if (*fmt == SECUREC_CHAR(']')) { prevChar = SECUREC_CHAR(']'); ++fmt; SECUREC_BRACKET_SET_BIT(bracketTable->table, SECUREC_CHAR(']')); } while (*fmt != SECUREC_CHAR('\0') && *fmt != SECUREC_CHAR(']')) { expCh = *fmt++; if (expCh != SECUREC_CHAR('-') || prevChar == 0 || *fmt == SECUREC_CHAR(']')) { /* normal character */ prevChar = expCh; SECUREC_BRACKET_SET_BIT(bracketTable->table, expCh); } else { /* for %[a-z] */ expCh = *fmt++; /* get end of range */ if (prevChar < expCh) { /* %[a-z] */ last = expCh; } else { prevChar = expCh; #if (defined(SECUREC_COMPATIBLE_WIN_FORMAT)) /* %[z-a] */ last = prevChar; #else SECUREC_BRACKET_SET_BIT(bracketTable->table, SECUREC_CHAR('-')); SECUREC_BRACKET_SET_BIT(bracketTable->table, expCh); continue; #endif } /* format %[a-\xff] last is 0xFF, condition (rnch <= last) cause dead loop */ for (expCh = prevChar; expCh < last; ++expCh) { SECUREC_BRACKET_SET_BIT(bracketTable->table, expCh); } SECUREC_BRACKET_SET_BIT(bracketTable->table, last); prevChar = 0; } } *format = fmt; return 0; } #ifdef SECUREC_FOR_WCHAR static int SecInputForWchar(SecInt ch, SecScanSpec *spec) { void *endPtr = spec->argPtr; if (spec->isWChar > 0) { *(wchar_t UNALIGNED *)endPtr = (wchar_t)ch; endPtr = (wchar_t *)endPtr + 1; --spec->arrayWidth; } else { #if SECUREC_HAVE_WCTOMB int temp; char tmpBuf[SECUREC_MB_LEN + 1]; SECUREC_MASK_MSVC_CRT_WARNING temp = wctomb(tmpBuf, (wchar_t)ch); SECUREC_END_MASK_MSVC_CRT_WARNING if (temp <= 0 || ((size_t)(unsigned int)temp) > sizeof(tmpBuf)) { /* if wctomb error, then ignore character */ return 0; } if (((size_t)(unsigned int)temp) > spec->arrayWidth) { return -1; } if (memcpy_s(endPtr, spec->arrayWidth, tmpBuf, (size_t)(unsigned int)temp) != EOK) { return -1; } endPtr = (char *)endPtr + temp; spec->arrayWidth -= (size_t)(unsigned int)temp; #else return -1; #endif } spec->argPtr = endPtr; return 0; } #endif #ifndef SECUREC_FOR_WCHAR static int SecInputForChar(SecInt ch, SecScanSpec *spec, SecFileStream *stream, int *charCount) { void *endPtr = spec->argPtr; if (spec->isWChar > 0) { wchar_t tempWChar = L'?'; /* set default char as ? 
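 * Editor's note (added): when the destination is wchar_t but the stream is
 * multibyte, the code below gathers the lead and continuation bytes into
 * temp[] and converts them with mbtowc(); if the conversion fails, the L'?'
 * placeholder initialized here is stored instead.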
*/ #if SECUREC_HAVE_MBTOWC char temp[SECUREC_MULTI_BYTE_MAX_LEN + 1]; temp[0] = (char)ch; temp[1] = '\0'; #if defined(SECUREC_COMPATIBLE_WIN_FORMAT) if (SecIsLeadByte(ch)) { temp[1] = (char)SecGetChar(stream, charCount); temp[2] = '\0'; /* 2 of string terminator position */ } if (mbtowc(&tempWChar, temp, sizeof(temp)) <= 0) { /* no string termination error for tool */ tempWChar = L'?'; } #else if (SecIsLeadByte(ch)) { int convRes = 0; int di = 1; /* in Linux like system, the string is encoded in UTF-8 */ while (convRes <= 0 && di < (int)MB_CUR_MAX && di < SECUREC_MULTI_BYTE_MAX_LEN) { temp[di++] = (char)SecGetChar(stream, charCount); temp[di] = '\0'; convRes = mbtowc(&tempWChar, temp, sizeof(temp)); } if (convRes <= 0) { tempWChar = L'?'; } } else { if (mbtowc(&tempWChar, temp, sizeof(temp)) <= 0) { /* no string termination error for tool */ tempWChar = L'?'; } } #endif #endif /* SECUREC_HAVE_MBTOWC */ *(wchar_t UNALIGNED *)endPtr = tempWChar; /* just copy L'?' if mbtowc fails, errno is set by mbtowc */ endPtr = (wchar_t *)endPtr + 1; --spec->arrayWidth; (void)charCount; (void)stream; } else { *(char *)endPtr = (char)ch; endPtr = (char *)endPtr + 1; --spec->arrayWidth; } spec->argPtr = endPtr; return 0; } #endif #if SECUREC_ENABLE_SCANF_FLOAT /* do not use localeconv()->decimal_point if only '.' is supported */ #define SECURE_IS_FLOAT_DECIMAL(ch) ((ch) == SECUREC_CHAR('.')) /* * init SecFloatSpec before parsing the format */ static void SecInitFloatSpec(SecFloatSpec *floatSpec) { floatSpec->floatStr = floatSpec->buffer; floatSpec->allocatedFloatStr = NULL; floatSpec->floatStrSize = sizeof(floatSpec->buffer) / sizeof(floatSpec->buffer[0]); floatSpec->floatStrUsedLen = 0; } static void SecClearFloatSpec(SecFloatSpec *floatSpec, int *doneCount) { /* LSD 2014.3.6 add, clear the stack data */ if (memset_s(floatSpec->buffer, sizeof(floatSpec->buffer), 0, sizeof(floatSpec->buffer)) != EOK) { *doneCount = 0; /* This is dead code, just to meet the coding requirements */ } if (floatSpec->allocatedFloatStr != NULL) { /* pFloatStr can be allocated in the SecUpdateFloatString function, clear and free it */ if (memset_s(floatSpec->allocatedFloatStr, floatSpec->floatStrSize * sizeof(SecChar), 0, floatSpec->floatStrSize * sizeof(SecChar)) != EOK) { *doneCount = 0; /* This is dead code, just to meet the coding requirements */ } SECUREC_FREE(floatSpec->allocatedFloatStr); floatSpec->allocatedFloatStr = NULL; floatSpec->floatStr = NULL; } } /* * scan value of exponent. * return 0 OK */ static int SecInputFloatE(SecFileStream *stream, SecScanSpec *spec, SecFloatSpec *floatSpec, int *charCount) { SecInt ch = SecGetChar(stream, charCount); if (ch == SECUREC_CHAR('+') || ch == SECUREC_CHAR('-')) { if (ch == SECUREC_CHAR('-') && SecUpdateFloatString((SecChar)'-', floatSpec) != 0) { return -1; } if (spec->width != 0) { ch = SecGetChar(stream, charCount); --spec->width; } } while (SECUREC_IS_DIGIT(ch) && spec->width-- != 0) { if (SecUpdateFloatString((SecChar)ch, floatSpec) != 0) { return -1; } ch = SecGetChar(stream, charCount); } return 0; } /* * scan %f.
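 * Editor's note (walkthrough, added): for an input such as "-12.5e-3" the
 * routine below copies the sign, the integral digits, the decimal point and
 * the fraction into floatSpec, then hands the exponent to SecInputFloatE
 * above; SecAssignFloat finally converts the text with strtod()/strtold().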
* return 0 OK */ static int SecInputFloat(SecFileStream *stream, SecScanSpec *spec, SecFloatSpec *floatSpec, int *charCount) { int started = -1; SecInt ch = SecGetChar(stream, charCount); floatSpec->floatStrUsedLen = 0; if (ch == SECUREC_CHAR('-')) { floatSpec->floatStr[floatSpec->floatStrUsedLen++] = SECUREC_CHAR('-'); --spec->width; ch = SecGetChar(stream, charCount); } else if (ch == SECUREC_CHAR('+')) { --spec->width; ch = SecGetChar(stream, charCount); } if (spec->widthSet == 0) { /* must care width */ spec->width = -1; /* -1 is unlimited */ } /* now get integral part */ while (SECUREC_IS_DIGIT(ch) && spec->width-- != 0) { started = 0; /* ch must be '0' - '9' */ if (SecUpdateFloatString((SecChar)ch, floatSpec) != 0) { return -1; } ch = SecGetChar(stream, charCount); } /* now get fractional part */ if (SECURE_IS_FLOAT_DECIMAL((SecChar)ch) && spec->width-- != 0) { /* now check for decimal */ if (SecUpdateFloatString((SecChar)ch, floatSpec) != 0) { return -1; } ch = SecGetChar(stream, charCount); while (SECUREC_IS_DIGIT(ch) && spec->width-- != 0) { started = 0; if (SecUpdateFloatString((SecChar)ch, floatSpec) != 0) { return -1; } ch = SecGetChar(stream, charCount); } } /* now get exponent part */ if (started == 0 && (ch == SECUREC_CHAR('e') || ch == SECUREC_CHAR('E')) && spec->width-- != 0) { if (SecUpdateFloatString((SecChar)'e', floatSpec) != 0) { return -1; } if (SecInputFloatE(stream, spec, floatSpec, charCount) != 0) { return -1; } } /* un set the last character that is not a floating point number */ SecUnGetChar(ch, stream, charCount); /* Make sure have a string terminator, buffer is large enough */ floatSpec->floatStr[floatSpec->floatStrUsedLen] = SECUREC_CHAR('\0'); return started; } #endif /* * scan digital part of %d %i %o %u %x %p. * return 0 OK */ static int SecInputNumberDigital(SecInt firstCh, SecFileStream *stream, SecScanSpec *spec, int *charCount) { SecInt ch = firstCh; int loopFlag = 0; int started = -1; while (loopFlag == 0) { /* decode ch to number */ loopFlag = SecDecodeNumber(ch, spec); if (loopFlag == 0) { started = 0; if (spec->widthSet != 0 && --spec->width == 0) { loopFlag = 1; } else { ch = SecGetChar(stream, charCount); } } else { SecUnGetChar(ch, stream, charCount); } } /* Handling integer negative numbers and beyond max */ (*g_secFinishNumber[spec->isInt64Arg])(spec); return started; } /* * scan %d %i %o %u %x %p. 
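 * Editor's note (added): %i auto-detects the base, so "0x1A" parses as
 * hexadecimal 26, "017" as octal 15 and "17" as decimal 17; the code below
 * implements this by defaulting comChr to 'd', then peeking after a leading
 * '0' and switching to 'x' or 'o' as appropriate.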
* return 0 OK */ static int SecInputNumber(SecFileStream *stream, SecScanSpec *spec, int *charCount) { SecInt ch = SecGetChar(stream, charCount); if (ch == SECUREC_CHAR('+') || ch == SECUREC_CHAR('-')) { if (ch == SECUREC_CHAR('-')) { spec->negative = 1; } if (spec->widthSet != 0 && --spec->width == 0) { return -1; } else { ch = SecGetChar(stream, charCount); } } if (spec->oriComChr == SECUREC_CHAR('i')) { /* i could be d, o, or x, use d as default */ spec->comChr = SECUREC_CHAR('d'); } if (spec->oriComChr == SECUREC_CHAR('x') || spec->oriComChr == SECUREC_CHAR('i')) { if (ch != SECUREC_CHAR('0')) { /* scan number */ return SecInputNumberDigital(ch, stream, spec, charCount); } /* now input string may be 0x123 or 0X123 or just 0 */ /* get next char */ ch = SecGetChar(stream, charCount); if ((SecChar)(ch) == SECUREC_CHAR('x') || (SecChar)ch == SECUREC_CHAR('X')) { spec->comChr = SECUREC_CHAR('x'); ch = SecGetChar(stream, charCount); /* length of 0x is 2 */ if (spec->widthSet != 0 && spec->width <= (1 + 1)) { /* length not enough for "0x" */ return -1; } spec->width -= 2; /* Subtract 2 for the length of "0x" */ } else { if (spec->oriComChr != SECUREC_CHAR('x')) { spec->comChr = SECUREC_CHAR('o'); } /* put the character after '0' back to the stream; an input of just '0' is OK */ SecUnGetChar(ch, stream, charCount); ch = SECUREC_CHAR('0'); } } /* scan number */ return SecInputNumberDigital(ch, stream, spec, charCount); } /* * scan %c %s %[ * return 0 OK */ static int SecInputString(SecFileStream *stream, SecScanSpec *spec, const SecBracketTable *bracketTable, int *charCount, int *doneCount) { void *startPtr = spec->argPtr; int suppressed = 0; int errNoMem = 0; while (spec->widthSet == 0 || spec->width-- != 0) { SecInt ch = SecGetChar(stream, charCount); /* char condition or string condition and bracket condition. * only supports wide characters with a maximum length of two bytes */ if ((ch != SECUREC_EOF) && (spec->comChr == SECUREC_CHAR('c') || SECUREC_SCANF_STRING_CONDITION(spec->comChr, ch) || SECUREC_SCANF_BRACKET_CONDITION(spec->comChr, ch, bracketTable->table, bracketTable->mask))) { if (spec->suppress != 0) { /* marks that data was processed for %*; * using endPtr for this would trigger lint 613, so use the suppressed flag */ suppressed = 1; continue; } /* now suppress is not set */ if (spec->arrayWidth == 0) { errNoMem = 1; /* We have exhausted the user's buffer */ break; } #ifdef SECUREC_FOR_WCHAR errNoMem = SecInputForWchar(ch, spec); #else errNoMem = SecInputForChar(ch, spec, stream, charCount); #endif if (errNoMem != 0) { break; } } else { SecUnGetChar(ch, stream, charCount); break; } } if (errNoMem != 0) { /* In case of error, blank out the input buffer */ if (spec->suppress == 0) { SecAddEndingZero(startPtr, spec); } return -1; } /* No input was scanned */ if ((spec->suppress != 0 && suppressed == 0) || (spec->suppress == 0 && startPtr == spec->argPtr)) { return -1; } if (spec->suppress == 0) { if (spec->comChr != 'c') { /* null-terminate strings */ SecAddEndingZero(spec->argPtr, spec); } *doneCount = *doneCount + 1; } return 0; } #ifdef SECUREC_FOR_WCHAR /* * allocate buffer for wchar version of %[.
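 * Editor's note (added): the char version keeps its 32-byte table inline in
 * SecBracketTable, but the wchar version needs SECUREC_BRACKET_TABLE_SIZE =
 * 8192 bytes (one bit for each of the 65536 two-byte code points), so it is
 * heap-allocated on first use here and released by SecFreeBracketTable once
 * scanning completes.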
* return 0 OK */ static int SecAllocBracketTable(SecBracketTable *bracketTable) { if (bracketTable->table == NULL) { /* table should be freed after use */ bracketTable->table = (unsigned char *)SECUREC_MALLOC(SECUREC_BRACKET_TABLE_SIZE); if (bracketTable->table == NULL) { return -1; } } return 0; } /* * free buffer for wchar version of %[ */ static void SecFreeBracketTable(SecBracketTable *bracketTable) { if (bracketTable->table != NULL) { SECUREC_FREE(bracketTable->table); bracketTable->table = NULL; } } #endif #ifdef SECUREC_FOR_WCHAR /* * Formatting input core function for the wchar version. Called by a function such as vswscanf_s */ int SecInputSW(SecFileStream *stream, const wchar_t *cFormat, va_list argList) #else /* * Formatting input core function for the char version. Called by a function such as vsscanf_s */ int SecInputS(SecFileStream *stream, const char *cFormat, va_list argList) #endif { const SecUnsignedChar *format = (const SecUnsignedChar *)cFormat; SecBracketTable bracketTable = SECUREC_INIT_BRACKET_TABLE; SecScanSpec spec; SecInt ch = 0; int charCount = 0; int doneCount = 0; int formatError = 0; int paraIsNull = 0; #if SECUREC_ENABLE_SCANF_FLOAT SecFloatSpec floatSpec; #endif int match = 0; int errRet = 0; #if SECUREC_ENABLE_SCANF_FLOAT SecInitFloatSpec(&floatSpec); #endif /* format must not be NULL */ /* use errRet < 1 to clear 845 */ while (errRet < 1 && *format != SECUREC_CHAR('\0')) { /* skip space in format and space in input */ if (SECUREC_IS_SPACE(*format)) { SecInt nonSpaceChar = SecSkipSpaceChar(stream, &charCount); /* eat all space chars and put the first non-space char back */ SecUnGetChar(nonSpaceChar, stream, &charCount); SecSkipSpaceFormat(&format); continue; } if (*format != SECUREC_CHAR('%')) { ch = SecGetChar(stream, &charCount); if ((int)(*format++) != (int)(ch)) { SecUnGetChar(ch, stream, &charCount); ++errRet; /* use plus to clear 845 */ continue; } #ifndef SECUREC_FOR_WCHAR if (SecIsLeadByte(ch) && SecDecodeLeadByte(ch, &format, stream, &charCount) != 0) { ++errRet; continue; } #endif /* for next %n */ if ((ch == SECUREC_EOF) && ((*format != SECUREC_CHAR('%')) || (*(format + 1) != SECUREC_CHAR('n')))) { break; } continue; } /* now *format is % */ /* set default value for each % */ SecSetDefaultScanSpec(&spec); if (SecDecodeScanFlag(&format, &spec) != 0) { formatError = 1; ++errRet; continue; } /* update wchar flag for %S %C */ SecUpdateWcharFlagByType(*format, &spec); #if SECUREC_HAVE_WCHART == 0 /* the kernel does not support wide chars */ if (spec.isWChar > 0) { formatError = 1; ++errRet; continue; } #endif if (spec.widthSet != 0 && spec.width == 0) { /* 0 width in format */ ++errRet; continue; } spec.comChr = (unsigned char)(*format) | (SECUREC_CHAR('a') - SECUREC_CHAR('A')); /* to lowercase */ spec.oriComChr = spec.comChr; if (spec.comChr != SECUREC_CHAR('n')) { if (spec.comChr != SECUREC_CHAR('c') && spec.comChr != SECUREC_BRACE) { ch = SecSkipSpaceChar(stream, &charCount); } else { ch = SecGetChar(stream, &charCount); } if (ch == SECUREC_EOF) { ++errRet; continue; } } /* now no 0 width in format and get one char from input */ switch (spec.comChr) { case SECUREC_CHAR('c'): /* also 'C' */ /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('s'): /* also 'S' */ /* fall-through */ /* FALLTHRU */ case SECUREC_BRACE: /* check dest buffer and size */ if (spec.suppress == 0) { spec.argPtr = (void *)va_arg(argList, void *); if (spec.argPtr == NULL) { paraIsNull = 1; ++errRet; continue; } /* Get the next argument - size of the array in characters */ #ifdef SECUREC_ON_64BITS
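/*
 * Editor's note (added): for %c, %s and %[ the secure scanf API consumes two
 * variadic arguments per conversion: the destination pointer fetched just
 * above, then the destination size in elements fetched below; on 64-bit
 * builds the size is masked to 32 bits to tolerate callers passing an int.
 */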
spec.arrayWidth = ((size_t)(va_arg(argList, size_t))) & 0xFFFFFFFFUL; #else /* !SECUREC_ON_64BITS */ spec.arrayWidth = (size_t)va_arg(argList, size_t); #endif if (spec.arrayWidth == 0 || (spec.isWChar <= 0 && spec.arrayWidth > SECUREC_STRING_MAX_LEN) || (spec.isWChar > 0 && spec.arrayWidth > SECUREC_WCHAR_STRING_MAX_LEN)) { /* do not clear buffer just go error */ ++errRet; continue; } /* One element is needed for '\0' for %s and %[ */ if (spec.comChr != SECUREC_CHAR('c')) { --spec.arrayWidth; } } else { /* Set argPtr to NULL is necessary, in supress mode we don't use argPtr to store data */ spec.argPtr = NULL; } if (spec.comChr == 'c') { if (spec.widthSet == 0) { spec.widthSet = 1; spec.width = 1; } } else if (spec.comChr == SECUREC_BRACE) { /* malloc when first %[ is meet for wchar version */ #ifdef SECUREC_FOR_WCHAR if (SecAllocBracketTable(&bracketTable) != 0) { ++errRet; continue; } #endif (void)memset(bracketTable.table, 0, (size_t)SECUREC_BRACKET_TABLE_SIZE); if (SecSetupBracketTable(&format, &bracketTable) != 0) { ++errRet; continue; } if (*format == SECUREC_CHAR('\0')) { if (spec.suppress == 0 && spec.arrayWidth > 0) { SecAddEndingZero(spec.argPtr, &spec); } ++errRet; /* truncated format */ continue; } } /* un set last char to stream */ SecUnGetChar(ch, stream, &charCount); /* scanset completed. Now read string */ if (SecInputString(stream, &spec, &bracketTable, &charCount, &doneCount) != 0) { ++errRet; continue; } break; case SECUREC_CHAR('p'): /* make %hp same as %p */ spec.numberWidth = SECUREC_NUM_WIDTH_INT; #ifdef SECUREC_ON_64BITS spec.isInt64Arg = 1; #endif /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('o'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('u'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('d'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('i'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('x'): /* un set last char to stream */ SecUnGetChar(ch, stream, &charCount); if (SecInputNumber(stream, &spec, &charCount) != 0) { ++errRet; continue; } if (spec.suppress == 0) { spec.argPtr = (void *)va_arg(argList, void *); if (spec.argPtr == NULL) { paraIsNull = 1; ++errRet; continue; } SecAssignNumber(&spec); ++doneCount; } break; case SECUREC_CHAR('n'): /* char count */ if (spec.suppress == 0) { spec.argPtr = (void *)va_arg(argList, void *); if (spec.argPtr == NULL) { paraIsNull = 1; ++errRet; continue; } spec.number = (unsigned long)(unsigned int)charCount; spec.isInt64Arg = 0; SecAssignNumber(&spec); } break; case SECUREC_CHAR('e'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('f'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('g'): /* scan a float */ #if SECUREC_ENABLE_SCANF_FLOAT /* un set last char to stream */ SecUnGetChar(ch, stream, &charCount); if (SecInputFloat(stream, &spec, &floatSpec, &charCount) != 0) { ++errRet; continue; } if (spec.suppress == 0) { spec.argPtr = (void *)va_arg(argList, void *); if (spec.argPtr == NULL) { ++errRet; paraIsNull = 1; continue; } #ifdef SECUREC_FOR_WCHAR if (SecAssignFloatW(&floatSpec, &spec) != 0) { ++errRet; continue; } #else SecAssignFloat(floatSpec.floatStr, spec.numberWidth, spec.argPtr); #endif ++doneCount; } break; #else /* SECUREC_ENABLE_SCANF_FLOAT */ ++errRet; continue; #endif default: if ((int)(*format) != (int)ch) { SecUnGetChar(ch, stream, &charCount); formatError = 1; ++errRet; continue; } else { --match; } } ++match; ++format; if ((ch == SECUREC_EOF) && ((*format != SECUREC_CHAR('%')) || (*(format + 1) != SECUREC_CHAR('n')))) { break; } } #ifdef SECUREC_FOR_WCHAR 
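/*
 * Editor's note (added): this is the shared epilogue of SecInputS/SecInputSW:
 * the wchar bracket table is freed, the float scratch buffers are wiped and
 * released, the buffered file position is rewound, and the function returns
 * doneCount on success, SECUREC_SCANF_EINVAL on EOF with nothing converted,
 * or SECUREC_SCANF_ERROR_PARA for a bad format string or NULL parameter.
 */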
SecFreeBracketTable(&bracketTable); #endif #if SECUREC_ENABLE_SCANF_FLOAT SecClearFloatSpec(&floatSpec, &doneCount); #endif #if SECUREC_ENABLE_SCANF_FILE SecAdjustStream(stream); #endif if (ch == SECUREC_EOF) { return ((doneCount || match) ? doneCount : SECUREC_SCANF_EINVAL); } else if (formatError != 0 || paraIsNull != 0) { /* Invalid Input Format or parameter */ return SECUREC_SCANF_ERROR_PARA; } return doneCount; } #if SECUREC_ENABLE_SCANF_FILE #if defined(SECUREC_NO_STD_UNGETC) /* * Get char from stdin or buffer */ static SecInt SecGetCharFromStdin(SecFileStream *stream) { SecInt ch; if (stream->fUnget == 1) { ch = (SecInt) stream->lastChar; stream->fUnget = 0; } else { ch = SECUREC_GETC(stream->pf); stream->lastChar = (unsigned int)ch; } return ch; } #else /* * Get char from stdin or buffer use std function */ static SecInt SecGetCharFromStdin(const SecFileStream *stream) { SecInt ch; ch = SECUREC_GETC(stream->pf); return ch; } #endif static void SecSkipBomHeader(SecFileStream *stream) { #ifdef SECUREC_FOR_WCHAR if (stream->count >= SECUREC_BOM_HEADER_SIZE && (((unsigned char)(stream->base[0]) == SECUREC_BOM_HEADER_LE_1ST && (unsigned char)(stream->base[1]) == SECUREC_BOM_HEADER_LE_2ST) || ((unsigned char)(stream->base[0]) == SECUREC_BOM_HEADER_BE_1ST && (unsigned char)(stream->base[1]) == SECUREC_BOM_HEADER_BE_2ST))) { /* the stream->count must be a multiple of sizeof(SecChar), * otherwise this function will return SECUREC_EOF when read the last character */ if ((stream->count - SECUREC_BOM_HEADER_SIZE) % (int)sizeof(SecChar) != 0) { int ret = (int)fread(stream->base + stream->count, (size_t)1, (size_t)SECUREC_BOM_HEADER_SIZE, stream->pf); if (ret > 0 && ret <= SECUREC_BUFFERED_BLOK_SIZE) { stream->count += ret; } } /* it's BOM header, skip */ stream->count -= SECUREC_BOM_HEADER_SIZE; stream->cur += SECUREC_BOM_HEADER_SIZE; } #else if (stream->count >= SECUREC_UTF8_BOM_HEADER_SIZE && (unsigned char)(stream->base[0]) == SECUREC_UTF8_BOM_HEADER_1ST && (unsigned char)(stream->base[1]) == SECUREC_UTF8_BOM_HEADER_2ND && (unsigned char)(stream->base[2]) == SECUREC_UTF8_BOM_HEADER_3RD) { /* 2 offset of third head character */ /* it's BOM header, skip */ stream->count -= SECUREC_UTF8_BOM_HEADER_SIZE; stream->cur += SECUREC_UTF8_BOM_HEADER_SIZE; } #endif } /* * Get char from file stream or buffer */ static SecInt SecGetCharFromFile(SecFileStream *stream) { SecInt ch; if (stream->count == 0) { int firstReadOnFile = 0; /* load file to buffer */ if (stream->base == NULL) { stream->base = (char *)SECUREC_MALLOC(SECUREC_BUFFERED_BLOK_SIZE + 1); if (stream->base == NULL) { return SECUREC_EOF; } stream->base[SECUREC_BUFFERED_BLOK_SIZE] = '\0'; /* for tool Warning string null */ } /* LSD add 2014.3.21 */ if (stream->oriFilePos == SECUREC_UNINITIALIZED_FILE_POS) { stream->oriFilePos = ftell(stream->pf); /* save original file read position */ firstReadOnFile = 1; } stream->count = (int)fread(stream->base, (size_t)1, (size_t)SECUREC_BUFFERED_BLOK_SIZE, stream->pf); stream->base[SECUREC_BUFFERED_BLOK_SIZE] = '\0'; /* for tool Warning string null */ if (stream->count == 0 || stream->count > SECUREC_BUFFERED_BLOK_SIZE) { return SECUREC_EOF; } stream->cur = stream->base; stream->flag |= SECUREC_LOAD_FILE_TO_MEM_FLAG; if (firstReadOnFile != 0) { SecSkipBomHeader(stream); } } /* according wchar_t has two bytes */ ch = (SecInt)((stream->count -= (int)sizeof(SecChar)) >= 0 ? 
\ (SecInt)(SECUREC_CHAR_MASK & \ (unsigned int)(int)(*((const SecChar *)(const void *)stream->cur))) : SECUREC_EOF); stream->cur += sizeof(SecChar); if (ch != SECUREC_EOF && stream->base != NULL) { stream->fileRealRead += (int)sizeof(SecChar); } return ch; } #endif /* * Get char for wchar version */ static SecInt SecGetChar(SecFileStream *stream, int *counter) { SecInt ch = SECUREC_EOF; #if SECUREC_ENABLE_SCANF_FILE if ((stream->flag & SECUREC_FROM_STDIN_FLAG) > 0) { ch = SecGetCharFromStdin(stream); } else if ((stream->flag & SECUREC_FILE_STREAM_FLAG) > 0) { ch = SecGetCharFromFile(stream); } #endif if ((stream->flag & SECUREC_MEM_STR_FLAG) > 0) { /* according wchar_t has two bytes */ ch = (SecInt)((stream->count -= (int)sizeof(SecChar)) >= 0 ? \ (SecInt)(SECUREC_CHAR_MASK & \ (unsigned int)(int)(*((const SecChar *)(const void *)stream->cur))) : SECUREC_EOF); stream->cur += sizeof(SecChar); } *counter = *counter + 1; return ch; } /* * Unget Public realizatio char for wchar and char version */ static void SecUnGetCharImpl(SecInt ch, SecFileStream *stream) { if ((stream->flag & SECUREC_FROM_STDIN_FLAG) > 0) { #if SECUREC_ENABLE_SCANF_FILE #if defined(SECUREC_NO_STD_UNGETC) stream->lastChar = (unsigned int)ch; stream->fUnget = 1; #else (void)SECUREC_UN_GETC(ch, stream->pf); #endif #else (void)ch; /* to clear e438 last value assigned not used , the compiler will optimize this code */ #endif } else if ((stream->flag & SECUREC_MEM_STR_FLAG) || (stream->flag & SECUREC_LOAD_FILE_TO_MEM_FLAG) > 0) { if (stream->cur > stream->base) { stream->cur -= sizeof(SecChar); stream->count += (int)sizeof(SecChar); } } #if SECUREC_ENABLE_SCANF_FILE if ((stream->flag & SECUREC_FILE_STREAM_FLAG) > 0 && stream->base) { stream->fileRealRead -= (int)sizeof(SecChar); } #endif } /* * Unget char for char version */ static void SecUnGetChar(SecInt ch, SecFileStream *stream, int *counter) { if (ch != SECUREC_EOF) { SecUnGetCharImpl(ch, stream); } *counter = *counter - 1; } /* * Skip space char by isspace */ static SecInt SecSkipSpaceChar(SecFileStream *stream, int *counter) { SecInt ch; do { ch = SecGetChar(stream, counter); } while (ch != SECUREC_EOF && SECUREC_IS_SPACE(ch)); return ch; } #endif /* __INPUT_INL__5D13A042_DC3F_4ED9_A8D1_882811274C27 */ ================================================ FILE: third_party/securec/src/memcpy_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #define SECUREC_INLINE_DO_MEMCPY 1 #include "securecutil.h" #ifndef SECUREC_MEMCOPY_WITH_PERFORMANCE #define SECUREC_MEMCOPY_WITH_PERFORMANCE 0 #endif #if SECUREC_WITH_PERFORMANCE_ADDONS || SECUREC_MEMCOPY_WITH_PERFORMANCE #ifndef SECUREC_MEMCOPY_THRESHOLD_SIZE #define SECUREC_MEMCOPY_THRESHOLD_SIZE 64UL #endif /* * Determine whether the address is 8-byte aligned, use static to increase performance * return 0 is aligned */ static int SecIsAddrAligned8(const void *addr, const void *zeroAddr) { return (int)(((size_t)((const char*)addr - (const char*)zeroAddr)) & 7); /* use 7 to check aligned 8 */ } #define SECUREC_SMALL_MEM_COPY do { \ if (SECUREC_ADDR_ALIGNED_8(dest) && SECUREC_ADDR_ALIGNED_8(src)) { \ /* use struct assignment */ \ switch (count) { \ case 1: \ *(SecStrBuf1 *)dest = *(const SecStrBuf1 *)src; \ break; \ case 2: \ *(SecStrBuf2 *)dest = *(const SecStrBuf2 *)src; \ break; \ case 3: \ *(SecStrBuf3 *)dest = *(const SecStrBuf3 *)src; \ break; \ case 4: \ *(SecStrBuf4 *)dest = *(const SecStrBuf4 *)src; \ break; \ case 5: \ *(SecStrBuf5 *)dest = *(const SecStrBuf5 *)src; \ break; \ case 6: \ *(SecStrBuf6 *)dest = *(const SecStrBuf6 *)src; \ break; \ case 7: \ *(SecStrBuf7 *)dest = *(const SecStrBuf7 *)src; \ break; \ case 8: \ *(SecStrBuf8 *)dest = *(const SecStrBuf8 *)src; \ break; \ case 9: \ *(SecStrBuf9 *)dest = *(const SecStrBuf9 *)src; \ break; \ case 10: \ *(SecStrBuf10 *)dest = *(const SecStrBuf10 *)src; \ break; \ case 11: \ *(SecStrBuf11 *)dest = *(const SecStrBuf11 *)src; \ break; \ case 12: \ *(SecStrBuf12 *)dest = *(const SecStrBuf12 *)src; \ break; \ case 13: \ *(SecStrBuf13 *)dest = *(const SecStrBuf13 *)src; \ break; \ case 14: \ *(SecStrBuf14 *)dest = *(const SecStrBuf14 *)src; \ break; \ case 15: \ *(SecStrBuf15 *)dest = *(const SecStrBuf15 *)src; \ break; \ case 16: \ *(SecStrBuf16 *)dest = *(const SecStrBuf16 *)src; \ break; \ case 17: \ *(SecStrBuf17 *)dest = *(const SecStrBuf17 *)src; \ break; \ case 18: \ *(SecStrBuf18 *)dest = *(const SecStrBuf18 *)src; \ break; \ case 19: \ *(SecStrBuf19 *)dest = *(const SecStrBuf19 *)src; \ break; \ case 20: \ *(SecStrBuf20 *)dest = *(const SecStrBuf20 *)src; \ break; \ case 21: \ *(SecStrBuf21 *)dest = *(const SecStrBuf21 *)src; \ break; \ case 22: \ *(SecStrBuf22 *)dest = *(const SecStrBuf22 *)src; \ break; \ case 23: \ *(SecStrBuf23 *)dest = *(const SecStrBuf23 *)src; \ break; \ case 24: \ *(SecStrBuf24 *)dest = *(const SecStrBuf24 *)src; \ break; \ case 25: \ *(SecStrBuf25 *)dest = *(const SecStrBuf25 *)src; \ break; \ case 26: \ *(SecStrBuf26 *)dest = *(const SecStrBuf26 *)src; \ break; \ case 27: \ *(SecStrBuf27 *)dest = *(const SecStrBuf27 *)src; \ break; \ case 28: \ *(SecStrBuf28 *)dest = *(const SecStrBuf28 *)src; \ break; \ case 29: \ *(SecStrBuf29 *)dest = *(const SecStrBuf29 *)src; \ break; \ case 30: \ *(SecStrBuf30 *)dest = *(const SecStrBuf30 *)src; \ break; \ case 31: \ *(SecStrBuf31 *)dest = *(const SecStrBuf31 *)src; \ break; \ case 32: \ *(SecStrBuf32 *)dest = *(const SecStrBuf32 *)src; \ break; \ case 33: \ *(SecStrBuf33 *)dest = *(const SecStrBuf33 *)src; \ break; \ case 34: \ *(SecStrBuf34 *)dest = *(const SecStrBuf34 *)src; \ break; \ case 35: \ *(SecStrBuf35 *)dest = *(const SecStrBuf35 *)src; \ break; \ case 36: \ *(SecStrBuf36 *)dest = *(const SecStrBuf36 *)src; \ break; \ case 37: \ *(SecStrBuf37 *)dest = *(const SecStrBuf37 *)src; \ break; \ case 38: \ *(SecStrBuf38 *)dest = *(const SecStrBuf38 *)src; \ break; \ case 39: \ *(SecStrBuf39 *)dest = *(const SecStrBuf39 *)src; \ break; \ case 40: \ 
*(SecStrBuf40 *)dest = *(const SecStrBuf40 *)src; \ break; \ case 41: \ *(SecStrBuf41 *)dest = *(const SecStrBuf41 *)src; \ break; \ case 42: \ *(SecStrBuf42 *)dest = *(const SecStrBuf42 *)src; \ break; \ case 43: \ *(SecStrBuf43 *)dest = *(const SecStrBuf43 *)src; \ break; \ case 44: \ *(SecStrBuf44 *)dest = *(const SecStrBuf44 *)src; \ break; \ case 45: \ *(SecStrBuf45 *)dest = *(const SecStrBuf45 *)src; \ break; \ case 46: \ *(SecStrBuf46 *)dest = *(const SecStrBuf46 *)src; \ break; \ case 47: \ *(SecStrBuf47 *)dest = *(const SecStrBuf47 *)src; \ break; \ case 48: \ *(SecStrBuf48 *)dest = *(const SecStrBuf48 *)src; \ break; \ case 49: \ *(SecStrBuf49 *)dest = *(const SecStrBuf49 *)src; \ break; \ case 50: \ *(SecStrBuf50 *)dest = *(const SecStrBuf50 *)src; \ break; \ case 51: \ *(SecStrBuf51 *)dest = *(const SecStrBuf51 *)src; \ break; \ case 52: \ *(SecStrBuf52 *)dest = *(const SecStrBuf52 *)src; \ break; \ case 53: \ *(SecStrBuf53 *)dest = *(const SecStrBuf53 *)src; \ break; \ case 54: \ *(SecStrBuf54 *)dest = *(const SecStrBuf54 *)src; \ break; \ case 55: \ *(SecStrBuf55 *)dest = *(const SecStrBuf55 *)src; \ break; \ case 56: \ *(SecStrBuf56 *)dest = *(const SecStrBuf56 *)src; \ break; \ case 57: \ *(SecStrBuf57 *)dest = *(const SecStrBuf57 *)src; \ break; \ case 58: \ *(SecStrBuf58 *)dest = *(const SecStrBuf58 *)src; \ break; \ case 59: \ *(SecStrBuf59 *)dest = *(const SecStrBuf59 *)src; \ break; \ case 60: \ *(SecStrBuf60 *)dest = *(const SecStrBuf60 *)src; \ break; \ case 61: \ *(SecStrBuf61 *)dest = *(const SecStrBuf61 *)src; \ break; \ case 62: \ *(SecStrBuf62 *)dest = *(const SecStrBuf62 *)src; \ break; \ case 63: \ *(SecStrBuf63 *)dest = *(const SecStrBuf63 *)src; \ break; \ case 64: \ *(SecStrBuf64 *)dest = *(const SecStrBuf64 *)src; \ break; \ default: \ break; \ } /* END switch */ \ } else { \ char *tmpDest = (char *)dest; \ const char *tmpSrc = (const char *)src; \ switch (count) { \ case 64: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 63: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 62: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 61: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 60: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 59: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 58: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 57: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 56: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 55: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 54: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 53: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 52: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 51: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 50: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 49: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 48: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 47: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 46: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 45: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 44: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ 
case 43: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 42: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 41: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 40: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 39: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 38: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 37: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 36: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 35: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 34: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 33: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 32: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 31: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 30: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 29: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 28: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 27: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 26: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 25: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 24: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 23: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 22: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 21: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 20: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 19: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 18: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 17: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 16: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 15: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 14: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 13: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 12: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 11: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 10: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 9: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 8: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 7: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 6: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 5: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 4: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 3: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 2: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 1: \ *(tmpDest++) = *(tmpSrc++); \ /* fall-through */ /* FALLTHRU */ \ default: \ break; \ } \ } \ } SECUREC_WHILE_ZERO #endif /* * Handling errors */ static errno_t SecMemcpyError(void *dest, size_t destMax, const void *src, size_t count) { if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) { 
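        /* a destMax of 0 or beyond SECUREC_MEM_MAX_LEN cannot be used to reset dest, so only ERANGE is reported */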
        SECUREC_ERROR_INVALID_RANGE("memcpy_s");
        return ERANGE;
    }
    if (dest == NULL || src == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("memcpy_s");
        if (dest != NULL) {
            (void)memset(dest, 0, destMax);
            return EINVAL_AND_RESET;
        }
        return EINVAL;
    }
    if (count > destMax) {
        (void)memset(dest, 0, destMax);
        SECUREC_ERROR_INVALID_RANGE("memcpy_s");
        return ERANGE_AND_RESET;
    }
    if (dest == src) {
        return EOK;
    }
    if ((dest > src && dest < (const void *)((const unsigned char *)src + count)) || \
        (src > dest && src < (void *)((unsigned char *)dest + count))) {
        (void)memset(dest, 0, destMax);
        SECUREC_ERROR_BUFFER_OVERLAP("memcpy_s");
        return EOVERLAP_AND_RESET;
    }
    /* count == 0 also returns EOK */
    return EOK;
}

#if SECUREC_WITH_PERFORMANCE_ADDONS || SECUREC_MEMCOPY_WITH_PERFORMANCE
/*
 * Performance optimization
 */
static void SecDoMemcpyOpt(void *dest, const void *src, size_t count)
{
    if (count > SECUREC_MEMCOPY_THRESHOLD_SIZE) {
        SecDoMemcpy(dest, src, count);
    } else {
        SECUREC_SMALL_MEM_COPY;
    }
    return;
}
#endif

#if defined(SECUREC_COMPATIBLE_WIN_FORMAT)
/* The fread API on Windows will call memcpy_s and pass 0xffffffff as destMax.
 * To avoid breaking fread, we don't check the destMax limit here.
 */
#define SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count) (SECUREC_LIKELY((count) <= (destMax) && \
    (dest) != NULL && (src) != NULL && \
    (count) > 0 && SECUREC_MEMORY_NO_OVERLAP((dest), (src), (count))))
#else
#define SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count) (SECUREC_LIKELY((count) <= (destMax) && \
    (dest) != NULL && (src) != NULL && \
    (destMax) <= SECUREC_MEM_MAX_LEN && \
    (count) > 0 && SECUREC_MEMORY_NO_OVERLAP((dest), (src), (count))))
#endif

/*
 *
 * The memcpy_s function copies count characters from the object pointed to by src
 * into the object pointed to by dest.
 *
 *
 * dest                 Destination buffer.
 * destMax              Size of the destination buffer.
 * src                  Buffer to copy from.
 * count                Number of characters to copy.
 *
 *
 * dest buffer is updated.
 *
 *
 * EOK                  Success
 * EINVAL               dest is NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN
 * EINVAL_AND_RESET     dest != NULL and src is NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN
 * ERANGE               destMax > SECUREC_MEM_MAX_LEN or destMax is 0
 * ERANGE_AND_RESET     count > destMax and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN
 *                      and dest != NULL and src != NULL
 * EOVERLAP_AND_RESET   dest buffer and source buffer are overlapped and count <= destMax
 *                      and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN and dest != NULL
 *                      and src != NULL and dest != src
 *
 * If an error occurred, dest will be filled with 0.
 * If the source and destination overlap, the behavior of memcpy_s is undefined.
 * Use memmove_s to handle overlapping regions.
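 *
 * A minimal usage sketch (illustrative only, not part of the library):
 *     char dst[8];
 *     errno_t rc = memcpy_s(dst, sizeof(dst), "abc", 4);   copies "abc" plus its terminator
 *     if (rc != EOK) {
 *         handle the error; dst has already been zero-filled on the _AND_RESET codes
 *     }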
 */
errno_t memcpy_s(void *dest, size_t destMax, const void *src, size_t count)
{
    if (SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count)) {
#if SECUREC_MEMCOPY_WITH_PERFORMANCE
        SecDoMemcpyOpt(dest, src, count);
#else
        SecDoMemcpy(dest, src, count);
#endif
        return EOK;
    }
    /* hit a runtime violation, return the error code */
    return SecMemcpyError(dest, destMax, src, count);
}
#if SECUREC_IN_KERNEL
EXPORT_SYMBOL(memcpy_s);
#endif

#if SECUREC_WITH_PERFORMANCE_ADDONS
/*
 * Performance optimization
 */
errno_t memcpy_sOptAsm(void *dest, size_t destMax, const void *src, size_t count)
{
    if (SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count)) {
        SecDoMemcpyOpt(dest, src, count);
        return EOK;
    }
    /* hit a runtime violation, return the error code */
    return SecMemcpyError(dest, destMax, src, count);
}

/* trims the judgement on "destMax <= SECUREC_MEM_MAX_LEN" */
errno_t memcpy_sOptTc(void *dest, size_t destMax, const void *src, size_t count)
{
    if (SECUREC_LIKELY(count <= destMax && dest != NULL && src != NULL && \
        count > 0 && \
        ((dest > src && (const void *)((const unsigned char *)src + count) <= dest) || \
        (src > dest && (void *)((unsigned char *)dest + count) <= src)))) {
        SecDoMemcpyOpt(dest, src, count);
        return EOK;
    }
    /* hit a runtime violation, return the error code */
    return SecMemcpyError(dest, destMax, src, count);
}
#endif


================================================
FILE: third_party/securec/src/memmove_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "securecutil.h"

#ifdef SECUREC_NOT_CALL_LIBC_CORE_API
/*
 * Implementing memory data movement
 */
static void SecUtilMemmove(void *dst, const void *src, size_t count)
{
    unsigned char *pDest = (unsigned char *)dst;
    const unsigned char *pSrc = (const unsigned char *)src;
    size_t maxCount = count;

    if (dst <= src || pDest >= (pSrc + maxCount)) {
        /*
         * Non-Overlapping Buffers
         * copy from lower addresses to higher addresses
         */
        while (maxCount--) {
            *pDest = *pSrc;
            ++pDest;
            ++pSrc;
        }
    } else {
        /*
         * Overlapping Buffers
         * copy from higher addresses to lower addresses
         */
        pDest = pDest + maxCount - 1;
        pSrc = pSrc + maxCount - 1;
        while (maxCount--) {
            *pDest = *pSrc;
            --pDest;
            --pSrc;
        }
    }
}
#endif

/*
 *
 * The memmove_s function copies count bytes from src to dest.
 * This function copies correctly even when the src and dest memory regions overlap.
 *
 * dest                 Destination object.
 * destMax              Size of the destination buffer.
 * src                  Source object.
 * count                Number of characters to copy.
 *
 *
 * dest buffer is updated.
 *
 *
 * EOK                  Success
 * EINVAL               dest is NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN
 * EINVAL_AND_RESET     dest != NULL and src is NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN
 * ERANGE               destMax > SECUREC_MEM_MAX_LEN or destMax is 0
 * ERANGE_AND_RESET     count > destMax and dest != NULL and src != NULL and destMax != 0
 *                      and destMax <= SECUREC_MEM_MAX_LEN
 *
 * If an error occurred, dest will be filled with 0 when dest and destMax are valid.
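 *
 * A minimal usage sketch (illustrative only, not part of the library):
 *     char buf[16] = "abcdefgh";
 *     errno_t rc = memmove_s(buf + 2, sizeof(buf) - 2, buf, 9);   overlapping in-buffer shift
 *     if (rc != EOK) {
 *         handle the error; buf + 2 has already been zero-filled on the _AND_RESET codes
 *     }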
 * If some regions of the source area and the destination overlap, memmove_s
 * ensures that the original source bytes in the overlapping region are copied
 * before being overwritten.
 */
errno_t memmove_s(void *dest, size_t destMax, const void *src, size_t count)
{
    if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) {
        SECUREC_ERROR_INVALID_RANGE("memmove_s");
        return ERANGE;
    }
    if (dest == NULL || src == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("memmove_s");
        if (dest != NULL) {
            (void)memset(dest, 0, destMax);
            return EINVAL_AND_RESET;
        }
        return EINVAL;
    }
    if (count > destMax) {
        (void)memset(dest, 0, destMax);
        SECUREC_ERROR_INVALID_RANGE("memmove_s");
        return ERANGE_AND_RESET;
    }
    if (dest == src) {
        return EOK;
    }

    if (count > 0) {
#ifdef SECUREC_NOT_CALL_LIBC_CORE_API
        SecUtilMemmove(dest, src, count);
#else
        /* use the underlying memmove for performance reasons */
        (void)memmove(dest, src, count);
#endif
    }
    return EOK;
}
#if SECUREC_IN_KERNEL
EXPORT_SYMBOL(memmove_s);
#endif


================================================
FILE: third_party/securec/src/memset_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#define SECUREC_INLINE_DO_MEMSET 1
#include "securecutil.h"

#ifndef SECUREC_MEMSET_WITH_PERFORMANCE
#define SECUREC_MEMSET_WITH_PERFORMANCE 0
#endif

#define SECUREC_MEMSET_PARAM_OK(dest, destMax, count) (SECUREC_LIKELY((count) <= (destMax) && \
    (dest) != NULL && (destMax) <= SECUREC_MEM_MAX_LEN))

#if SECUREC_WITH_PERFORMANCE_ADDONS || SECUREC_MEMSET_WITH_PERFORMANCE
/*
 * Determine whether the address is 8-byte aligned; use static to increase performance
 * returns 0 if aligned
 */
static int SecIsAddrAligned8(const void *addr, const void *zeroAddr)
{
    return (int)(((size_t)((const char*)addr - (const char*)zeroAddr)) & 7); /* use 7 to check aligned 8 */
}

/* use a union to clear the strict-aliasing warning */
typedef union {
    SecStrBuf32 buf32; SecStrBuf31 buf31; SecStrBuf30 buf30; SecStrBuf29 buf29;
    SecStrBuf28 buf28; SecStrBuf27 buf27; SecStrBuf26 buf26; SecStrBuf25 buf25;
    SecStrBuf24 buf24; SecStrBuf23 buf23; SecStrBuf22 buf22; SecStrBuf21 buf21;
    SecStrBuf20 buf20; SecStrBuf19 buf19; SecStrBuf18 buf18; SecStrBuf17 buf17;
    SecStrBuf16 buf16; SecStrBuf15 buf15; SecStrBuf14 buf14; SecStrBuf13 buf13;
    SecStrBuf12 buf12; SecStrBuf11 buf11; SecStrBuf10 buf10; SecStrBuf9 buf9;
    SecStrBuf8 buf8; SecStrBuf7 buf7; SecStrBuf6 buf6; SecStrBuf5 buf5;
    SecStrBuf4 buf4; SecStrBuf3 buf3; SecStrBuf2 buf2; SecStrBuf1 buf1;
} SecStrBuf32Union;
/* The C standard initializes the first member of the union.
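 * Reading g_allZero / g_allFF below through SecStrBuf32Union is what lets
 * SecStrictAliasingCast reinterpret the same 32 bytes as any smaller
 * SecStrBufN member for the struct assignments in SECUREC_ALIGNED_SET_OPT_ZERO_FF.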
*/ static const SecStrBuf32 g_allZero = {{ '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0' }}; static const SecStrBuf32 g_allFF = {{ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF }}; static const SecStrBuf32Union *SecStrictAliasingCast(const SecStrBuf32 *buf) { return (const SecStrBuf32Union *)buf; } #ifndef SECUREC_MEMSET_THRESHOLD_SIZE #define SECUREC_MEMSET_THRESHOLD_SIZE 32UL #endif #define SECUREC_UNALIGNED_SET do { \ char *pcDest = (char *)dest; \ switch (count) { \ case 32: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 31: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 30: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 29: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 28: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 27: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 26: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 25: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 24: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 23: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 22: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 21: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 20: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 19: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 18: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 17: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 16: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 15: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 14: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 13: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 12: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 11: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 10: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 9: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 8: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 7: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 6: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 5: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 4: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 3: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 2: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ case 1: \ *(pcDest++) = (char)c; \ /* fall-through */ /* FALLTHRU */ \ default: \ break; \ } \ } SECUREC_WHILE_ZERO #define SECUREC_ALIGNED_SET_OPT_ZERO_FF do { \ switch (c) { \ case 0: \ switch (count) { \ case 1: \ *(SecStrBuf1 *)dest = *(const SecStrBuf1 *)(&((SecStrictAliasingCast(&g_allZero))->buf1)); \ break; \ case 2: \ *(SecStrBuf2 *)dest = *(const SecStrBuf2 *)(&((SecStrictAliasingCast(&g_allZero))->buf2)); \ break; \ case 3: \ *(SecStrBuf3 *)dest = *(const SecStrBuf3 *)(&((SecStrictAliasingCast(&g_allZero))->buf3)); \ break; \ case 4: \ *(SecStrBuf4 
*)dest = *(const SecStrBuf4 *)(&((SecStrictAliasingCast(&g_allZero))->buf4)); \ break; \ case 5: \ *(SecStrBuf5 *)dest = *(const SecStrBuf5 *)(&((SecStrictAliasingCast(&g_allZero))->buf5)); \ break; \ case 6: \ *(SecStrBuf6 *)dest = *(const SecStrBuf6 *)(&((SecStrictAliasingCast(&g_allZero))->buf6)); \ break; \ case 7: \ *(SecStrBuf7 *)dest = *(const SecStrBuf7 *)(&((SecStrictAliasingCast(&g_allZero))->buf7)); \ break; \ case 8: \ *(SecStrBuf8 *)dest = *(const SecStrBuf8 *)(&((SecStrictAliasingCast(&g_allZero))->buf8)); \ break; \ case 9: \ *(SecStrBuf9 *)dest = *(const SecStrBuf9 *)(&((SecStrictAliasingCast(&g_allZero))->buf9)); \ break; \ case 10: \ *(SecStrBuf10 *)dest = *(const SecStrBuf10 *)(&((SecStrictAliasingCast(&g_allZero))->buf10)); \ break; \ case 11: \ *(SecStrBuf11 *)dest = *(const SecStrBuf11 *)(&((SecStrictAliasingCast(&g_allZero))->buf11)); \ break; \ case 12: \ *(SecStrBuf12 *)dest = *(const SecStrBuf12 *)(&((SecStrictAliasingCast(&g_allZero))->buf12)); \ break; \ case 13: \ *(SecStrBuf13 *)dest = *(const SecStrBuf13 *)(&((SecStrictAliasingCast(&g_allZero))->buf13)); \ break; \ case 14: \ *(SecStrBuf14 *)dest = *(const SecStrBuf14 *)(&((SecStrictAliasingCast(&g_allZero))->buf14)); \ break; \ case 15: \ *(SecStrBuf15 *)dest = *(const SecStrBuf15 *)(&((SecStrictAliasingCast(&g_allZero))->buf15)); \ break; \ case 16: \ *(SecStrBuf16 *)dest = *(const SecStrBuf16 *)(&((SecStrictAliasingCast(&g_allZero))->buf16)); \ break; \ case 17: \ *(SecStrBuf17 *)dest = *(const SecStrBuf17 *)(&((SecStrictAliasingCast(&g_allZero))->buf17)); \ break; \ case 18: \ *(SecStrBuf18 *)dest = *(const SecStrBuf18 *)(&((SecStrictAliasingCast(&g_allZero))->buf18)); \ break; \ case 19: \ *(SecStrBuf19 *)dest = *(const SecStrBuf19 *)(&((SecStrictAliasingCast(&g_allZero))->buf19)); \ break; \ case 20: \ *(SecStrBuf20 *)dest = *(const SecStrBuf20 *)(&((SecStrictAliasingCast(&g_allZero))->buf20)); \ break; \ case 21: \ *(SecStrBuf21 *)dest = *(const SecStrBuf21 *)(&((SecStrictAliasingCast(&g_allZero))->buf21)); \ break; \ case 22: \ *(SecStrBuf22 *)dest = *(const SecStrBuf22 *)(&((SecStrictAliasingCast(&g_allZero))->buf22)); \ break; \ case 23: \ *(SecStrBuf23 *)dest = *(const SecStrBuf23 *)(&((SecStrictAliasingCast(&g_allZero))->buf23)); \ break; \ case 24: \ *(SecStrBuf24 *)dest = *(const SecStrBuf24 *)(&((SecStrictAliasingCast(&g_allZero))->buf24)); \ break; \ case 25: \ *(SecStrBuf25 *)dest = *(const SecStrBuf25 *)(&((SecStrictAliasingCast(&g_allZero))->buf25)); \ break; \ case 26: \ *(SecStrBuf26 *)dest = *(const SecStrBuf26 *)(&((SecStrictAliasingCast(&g_allZero))->buf26)); \ break; \ case 27: \ *(SecStrBuf27 *)dest = *(const SecStrBuf27 *)(&((SecStrictAliasingCast(&g_allZero))->buf27)); \ break; \ case 28: \ *(SecStrBuf28 *)dest = *(const SecStrBuf28 *)(&((SecStrictAliasingCast(&g_allZero))->buf28)); \ break; \ case 29: \ *(SecStrBuf29 *)dest = *(const SecStrBuf29 *)(&((SecStrictAliasingCast(&g_allZero))->buf29)); \ break; \ case 30: \ *(SecStrBuf30 *)dest = *(const SecStrBuf30 *)(&((SecStrictAliasingCast(&g_allZero))->buf30)); \ break; \ case 31: \ *(SecStrBuf31 *)dest = *(const SecStrBuf31 *)(&((SecStrictAliasingCast(&g_allZero))->buf31)); \ break; \ case 32: \ *(SecStrBuf32 *)dest = *(const SecStrBuf32 *)(&((SecStrictAliasingCast(&g_allZero))->buf32)); \ break; \ default: \ break; \ } \ break; \ case 0xFF: \ switch (count) { \ case 1: \ *(SecStrBuf1 *)dest = *(const SecStrBuf1 *)(&((SecStrictAliasingCast(&g_allFF))->buf1)); \ break; \ case 2: \ *(SecStrBuf2 *)dest = *(const SecStrBuf2 
*)(&((SecStrictAliasingCast(&g_allFF))->buf2)); \ break; \ case 3: \ *(SecStrBuf3 *)dest = *(const SecStrBuf3 *)(&((SecStrictAliasingCast(&g_allFF))->buf3)); \ break; \ case 4: \ *(SecStrBuf4 *)dest = *(const SecStrBuf4 *)(&((SecStrictAliasingCast(&g_allFF))->buf4)); \ break; \ case 5: \ *(SecStrBuf5 *)dest = *(const SecStrBuf5 *)(&((SecStrictAliasingCast(&g_allFF))->buf5)); \ break; \ case 6: \ *(SecStrBuf6 *)dest = *(const SecStrBuf6 *)(&((SecStrictAliasingCast(&g_allFF))->buf6)); \ break; \ case 7: \ *(SecStrBuf7 *)dest = *(const SecStrBuf7 *)(&((SecStrictAliasingCast(&g_allFF))->buf7)); \ break; \ case 8: \ *(SecStrBuf8 *)dest = *(const SecStrBuf8 *)(&((SecStrictAliasingCast(&g_allFF))->buf8)); \ break; \ case 9: \ *(SecStrBuf9 *)dest = *(const SecStrBuf9 *)(&((SecStrictAliasingCast(&g_allFF))->buf9)); \ break; \ case 10: \ *(SecStrBuf10 *)dest = *(const SecStrBuf10 *)(&((SecStrictAliasingCast(&g_allFF))->buf10)); \ break; \ case 11: \ *(SecStrBuf11 *)dest = *(const SecStrBuf11 *)(&((SecStrictAliasingCast(&g_allFF))->buf11)); \ break; \ case 12: \ *(SecStrBuf12 *)dest = *(const SecStrBuf12 *)(&((SecStrictAliasingCast(&g_allFF))->buf12)); \ break; \ case 13: \ *(SecStrBuf13 *)dest = *(const SecStrBuf13 *)(&((SecStrictAliasingCast(&g_allFF))->buf13)); \ break; \ case 14: \ *(SecStrBuf14 *)dest = *(const SecStrBuf14 *)(&((SecStrictAliasingCast(&g_allFF))->buf14)); \ break; \ case 15: \ *(SecStrBuf15 *)dest = *(const SecStrBuf15 *)(&((SecStrictAliasingCast(&g_allFF))->buf15)); \ break; \ case 16: \ *(SecStrBuf16 *)dest = *(const SecStrBuf16 *)(&((SecStrictAliasingCast(&g_allFF))->buf16)); \ break; \ case 17: \ *(SecStrBuf17 *)dest = *(const SecStrBuf17 *)(&((SecStrictAliasingCast(&g_allFF))->buf17)); \ break; \ case 18: \ *(SecStrBuf18 *)dest = *(const SecStrBuf18 *)(&((SecStrictAliasingCast(&g_allFF))->buf18)); \ break; \ case 19: \ *(SecStrBuf19 *)dest = *(const SecStrBuf19 *)(&((SecStrictAliasingCast(&g_allFF))->buf19)); \ break; \ case 20: \ *(SecStrBuf20 *)dest = *(const SecStrBuf20 *)(&((SecStrictAliasingCast(&g_allFF))->buf20)); \ break; \ case 21: \ *(SecStrBuf21 *)dest = *(const SecStrBuf21 *)(&((SecStrictAliasingCast(&g_allFF))->buf21)); \ break; \ case 22: \ *(SecStrBuf22 *)dest = *(const SecStrBuf22 *)(&((SecStrictAliasingCast(&g_allFF))->buf22)); \ break; \ case 23: \ *(SecStrBuf23 *)dest = *(const SecStrBuf23 *)(&((SecStrictAliasingCast(&g_allFF))->buf23)); \ break; \ case 24: \ *(SecStrBuf24 *)dest = *(const SecStrBuf24 *)(&((SecStrictAliasingCast(&g_allFF))->buf24)); \ break; \ case 25: \ *(SecStrBuf25 *)dest = *(const SecStrBuf25 *)(&((SecStrictAliasingCast(&g_allFF))->buf25)); \ break; \ case 26: \ *(SecStrBuf26 *)dest = *(const SecStrBuf26 *)(&((SecStrictAliasingCast(&g_allFF))->buf26)); \ break; \ case 27: \ *(SecStrBuf27 *)dest = *(const SecStrBuf27 *)(&((SecStrictAliasingCast(&g_allFF))->buf27)); \ break; \ case 28: \ *(SecStrBuf28 *)dest = *(const SecStrBuf28 *)(&((SecStrictAliasingCast(&g_allFF))->buf28)); \ break; \ case 29: \ *(SecStrBuf29 *)dest = *(const SecStrBuf29 *)(&((SecStrictAliasingCast(&g_allFF))->buf29)); \ break; \ case 30: \ *(SecStrBuf30 *)dest = *(const SecStrBuf30 *)(&((SecStrictAliasingCast(&g_allFF))->buf30)); \ break; \ case 31: \ *(SecStrBuf31 *)dest = *(const SecStrBuf31 *)(&((SecStrictAliasingCast(&g_allFF))->buf31)); \ break; \ case 32: \ *(SecStrBuf32 *)dest = *(const SecStrBuf32 *)(&((SecStrictAliasingCast(&g_allFF))->buf32)); \ break; \ default: \ break; \ } \ break; \ default: \ SECUREC_UNALIGNED_SET; \ } /* END switch */ \ } 
SECUREC_WHILE_ZERO
#endif

/*
 * Handling errors
 */
static errno_t SecMemsetError(void *dest, size_t destMax, int c, size_t count)
{
    if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) {
        SECUREC_ERROR_INVALID_RANGE("memset_s");
        return ERANGE;
    }
    if (dest == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("memset_s");
        return EINVAL;
    }
    if (count > destMax) {
        (void)memset(dest, c, destMax); /* set entire buffer to value c */
        SECUREC_ERROR_INVALID_RANGE("memset_s");
        return ERANGE_AND_RESET;
    }
    return EOK;
}

#if SECUREC_WITH_PERFORMANCE_ADDONS || SECUREC_MEMSET_WITH_PERFORMANCE
/*
 * Performance optimization
 */
static void SecDoMemsetOpt(void *dest, int c, size_t count)
{
    if (count > SECUREC_MEMSET_THRESHOLD_SIZE) {
        SecDoMemset(dest, c, count);
    } else {
        if (SECUREC_ADDR_ALIGNED_8(dest)) {
            /* use struct assignment */
            SECUREC_ALIGNED_SET_OPT_ZERO_FF;
        } else {
            SECUREC_UNALIGNED_SET;
        }
    }
    return;
}
#endif

/*
 *
 * The memset_s function copies the value of c (converted to an unsigned char)
 * into each of the first count characters of the object pointed to by dest.
 *
 *
 * dest                 Pointer to destination.
 * destMax              The size of the buffer.
 * c                    Character to set.
 * count                Number of characters.
 *
 *
 * dest buffer is updated.
 *
 *
 * EOK                  Success
 * EINVAL               dest == NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN
 * ERANGE               destMax is 0 or destMax > SECUREC_MEM_MAX_LEN
 * ERANGE_AND_RESET     count > destMax and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN and dest != NULL
 *
 * If ERANGE_AND_RESET is returned, dest is filled with the value c for destMax characters.
 */
errno_t memset_s(void *dest, size_t destMax, int c, size_t count)
{
    if (SECUREC_MEMSET_PARAM_OK(dest, destMax, count)) {
#if SECUREC_MEMSET_WITH_PERFORMANCE
        SecDoMemsetOpt(dest, c, count);
#else
        SecDoMemset(dest, c, count);
#endif
        return EOK;
    } else {
        /* hit a runtime violation, return the error code */
        return SecMemsetError(dest, destMax, c, count);
    }
}
#if SECUREC_IN_KERNEL
EXPORT_SYMBOL(memset_s);
#endif

#if SECUREC_WITH_PERFORMANCE_ADDONS
/*
 * Performance optimization
 */
errno_t memset_sOptAsm(void *dest, size_t destMax, int c, size_t count)
{
    if (SECUREC_MEMSET_PARAM_OK(dest, destMax, count)) {
        SecDoMemsetOpt(dest, c, count);
        return EOK;
    }
    /* hit a runtime violation, return the error code */
    return SecMemsetError(dest, destMax, c, count);
}

/*
 * Performance optimization
 */
errno_t memset_sOptTc(void *dest, size_t destMax, int c, size_t count)
{
    if (SECUREC_LIKELY(count <= destMax && dest != NULL)) {
        SecDoMemsetOpt(dest, c, count);
        return EOK;
    }
    /* hit a runtime violation, return the error code */
    return SecMemsetError(dest, destMax, c, count);
}
#endif


================================================
FILE: third_party/securec/src/output.inl
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #ifndef OUTPUT_INL_2B263E9C_43D8_44BB_B17A_6D2033DECEE5 #define OUTPUT_INL_2B263E9C_43D8_44BB_B17A_6D2033DECEE5 #define SECUREC_NULL_STRING_SIZE 8 #define SECUREC_STATE_TABLE_SIZE 337 #define SECUREC_OFFSET_BITS_WORD 16 #define SECUREC_OFFSET_BITS_DWORD 32 #define SECUREC_OFFSET_DIV_OCTAL 3 #define SECUREC_OFFSET_DIV_HEX 4 #define SECUREC_RADIX_OCTAL 8 #define SECUREC_RADIX_DECIMAL 10 #define SECUREC_RADIX_HEX 16 /* Use two displacements to eliminate compilation warnings */ #define SECUREC_SHR_DWORD(x) (((x) >> 16) >> 16) #define SECUREC_PREFIX_LEN 2 /* size include '+' and '\0' */ #define SECUREC_FLOAT_BUF_EXT 2 #ifdef SECUREC_STACK_SIZE_LESS_THAN_1K #define SECUREC_FMT_STR_LEN 8 #else #define SECUREC_FMT_STR_LEN 16 #endif typedef struct { unsigned int flags; int fldWidth; int precision; int bufferIsWide; /* flag for buffer contains wide chars ;0 is not wide char */ int dynWidth; /* %* 1 width from variable parameter ;0 not */ int dynPrecision; /* %.* 1 precision from variable parameter ;0 not */ } SecFormatAttr; typedef union { char *str; /* not a null terminated string */ #if SECUREC_HAVE_WCHART wchar_t *wStr; #endif } SecFormatBuf; typedef union { char str[SECUREC_BUFFER_SIZE + 1]; #ifdef SECUREC_FOR_WCHAR wchar_t wStr[SECUREC_BUFFER_SIZE + 1]; #endif } SecBuffer; #if SECUREC_ENABLE_SPRINTF_FLOAT /* call system sprintf to format float value */ static int SecIndirectSprintf(char *strDest, const char *format, ...) { int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); SECUREC_MASK_MSVC_CRT_WARNING ret = vsprintf(strDest, format, argList); SECUREC_END_MASK_MSVC_CRT_WARNING va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT /* out put long double value to dest */ static int SecFormatLongDboule(char *strDest,const SecFormatAttr *formatAttr, const char *fmt, long double ldValue) { int fldWidth = ((formatAttr->flags & SECUREC_FLAG_LEFT) ? (-(formatAttr->fldWidth)) : formatAttr->fldWidth); if (formatAttr->dynWidth && formatAttr->dynPrecision) { return SecIndirectSprintf(strDest, fmt, fldWidth, formatAttr->precision, ldValue); } else if (formatAttr->dynWidth) { return SecIndirectSprintf(strDest, fmt, fldWidth, ldValue); } else if (formatAttr->dynPrecision) { return SecIndirectSprintf(strDest, fmt, formatAttr->precision, ldValue); } return SecIndirectSprintf(strDest, fmt, ldValue); } #endif /* out put double value to dest */ static int SecFormatDboule(char *strDest, const SecFormatAttr *formatAttr, const char *fmt, double dValue) { int fldWidth = ((formatAttr->flags & SECUREC_FLAG_LEFT) ? 
(-(formatAttr->fldWidth)) : formatAttr->fldWidth); if (formatAttr->dynWidth && formatAttr->dynPrecision) { return SecIndirectSprintf(strDest, fmt, fldWidth, formatAttr->precision, dValue); } else if (formatAttr->dynWidth) { return SecIndirectSprintf(strDest, fmt, fldWidth, dValue); } else if (formatAttr->dynPrecision) { return SecIndirectSprintf(strDest, fmt, formatAttr->precision, dValue); } return SecIndirectSprintf(strDest, fmt, dValue); } #endif #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT /* to clear e506 warning */ static int SecIsSameSize(size_t sizeA, size_t sizeB) { return sizeA == sizeB; } #endif #define SECUREC_SPECIAL_DWORD(val32, numBase) do { \ --formatBuf.str; \ *(formatBuf.str) = digits[(val32) % (numBase)]; \ } while (((val32) /= (numBase)) != 0) #if defined(SECUREC_USE_SPECIAL_DIV64) || (defined(SECUREC_VXWORKS_VERSION_5_4) && !defined(SECUREC_ON_64BITS)) /* * Fast divide by 10 algorithm. * Calculation divisor multiply 0xcccccccccccccccdULL, resultHi64 >> 3 as quotient */ static void SecU64Div10(SecUnsignedInt64 divisor, SecUnsignedInt64 *quotient, SecUnsignedInt32 *remainder) { SecUnsignedInt64 mask = 0xffffffffULL; /* use 0xffffffffULL as 32 bit mask */ SecUnsignedInt64 magicHi = 0xccccccccULL; /* fast divide 10 magic numbers high 32bit 0xccccccccULL */ SecUnsignedInt64 magicLow = 0xcccccccdULL; /* fast divide 10 magic numbers low 32bit 0xcccccccdULL */ SecUnsignedInt64 divisorHi = (SecUnsignedInt64)(SECUREC_SHR_DWORD(divisor)); /* hig 32 bit use */ SecUnsignedInt64 divisorLow = (SecUnsignedInt64)(divisor & mask); /* low 32 bit mask */ SecUnsignedInt64 factorHi = divisorHi * magicHi; SecUnsignedInt64 factorLow1 = divisorHi * magicLow; SecUnsignedInt64 factorLow2 = divisorLow * magicHi; SecUnsignedInt64 factorLow3 = divisorLow * magicLow; SecUnsignedInt64 carry = (factorLow1 & mask) + (factorLow2 & mask) + SECUREC_SHR_DWORD(factorLow3); SecUnsignedInt64 resultHi64 = factorHi + SECUREC_SHR_DWORD(factorLow1) + \ SECUREC_SHR_DWORD(factorLow2) + SECUREC_SHR_DWORD(carry); *quotient = resultHi64 >> 3; /* fast divide 10 magic numbers 3 */ *remainder = (SecUnsignedInt32)(divisor - ((*quotient) * 10)); /* quotient mul 10 */ return; } #if defined(SECUREC_VXWORKS_VERSION_5_4) && !defined(SECUREC_ON_64BITS) /* * Divide function for VXWORKS */ static int SecU64Div32(SecUnsignedInt64 divisor, SecUnsignedInt32 radix, SecUnsignedInt64 *quotient, SecUnsignedInt32 *remainder) { switch (radix) { case SECUREC_RADIX_DECIMAL: SecU64Div10(divisor, quotient, remainder); break; case SECUREC_RADIX_HEX: *quotient = divisor >> SECUREC_OFFSET_DIV_HEX; *remainder = divisor & 0xfULL; /* mask one hex number by 0xfULL */ break; case SECUREC_RADIX_OCTAL: *quotient = divisor >> SECUREC_OFFSET_DIV_OCTAL; *remainder = divisor & 0x7ULL; /* mask one hex number by 0x7ULL */ break; default: return -1; } return 0; } #endif #endif #if defined(SECUREC_USE_SPECIAL_DIV64) /* The compiler does not provide 64 bit division problems */ #define SECUREC_SPECIAL_QWORD_BASE10(val64) do { \ SecUnsignedInt64 quotient = 0; \ SecUnsignedInt32 digit = 0; \ SecU64Div10((val64), &(quotient), &(digit)); \ --formatBuf.str; \ *(formatBuf.str) = digits[digit]; \ (val64) = quotient; \ } while ((val64) != 0) #else #define SECUREC_SPECIAL_QWORD_BASE10(val64) do { \ --formatBuf.str; \ *(formatBuf.str) = digits[(val64) % SECUREC_RADIX_DECIMAL]; \ } while (((val64) /= SECUREC_RADIX_DECIMAL) != 0) #endif #define SECUREC_SPECIAL_QWORD(val64, numBase) do { \ --formatBuf.str; \ *(formatBuf.str) = digits[(val64) % (numBase)]; \ } while (((val64) /= 
(numBase)) != 0) #define SECUREC_SAFE_WRITE_STR_OPT(src, txtLen, outStream, outChars) do { \ int ii_; \ for (ii_ = 0; ii_ < (txtLen); ++ii_) { \ *((SecChar *)(void *)((outStream)->cur)) = *(SecChar *)(src); \ (outStream)->cur += sizeof(SecChar); \ (src) = (src) + 1; \ } \ (outStream)->count -= (txtLen) * (int)(sizeof(SecChar)); \ *(outChars) = *(outChars) + (txtLen); \ } SECUREC_WHILE_ZERO #define SECUREC_SAFE_WRITE_STR(src, txtLen, outStream, outChars) do { \ if ((txtLen) < 12) { /* performance optimization for mobile number length 12 */ \ SECUREC_SAFE_WRITE_STR_OPT((src), (txtLen), (outStream), (outChars)); \ } else { \ SecDoMemcpy((outStream)->cur, (src), ((size_t)(unsigned int)(txtLen) * (sizeof(SecChar)))); \ (outStream)->cur += (size_t)((size_t)(unsigned int)(txtLen) * (sizeof(SecChar))); \ (outStream)->count -= (txtLen) * (int)(sizeof(SecChar)); \ *(outChars) = *(outChars) + (txtLen); \ } \ } SECUREC_WHILE_ZERO #define SECUREC_SAFE_WRITE_CHAR(c, outStream, outChars) do { \ *((SecChar *)(void *)((outStream)->cur)) = (SecChar)(c); \ (outStream)->cur += sizeof(SecChar); \ (outStream)->count -= (int)(sizeof(SecChar)); \ *(outChars) = *(outChars) + 1; \ } SECUREC_WHILE_ZERO #define SECUREC_SAFE_PADDING(padChar, padLen, outStream, outChars) do { \ int ii_; \ for (ii_ = 0; ii_ < (padLen); ++ii_) { \ *((SecChar *)(void *)((outStream)->cur)) = (SecChar)(padChar); \ (outStream)->cur += sizeof(SecChar); \ } \ (outStream)->count -= (padLen) * (int)(sizeof(SecChar)); \ *(outChars) = *(outChars) + (padLen); \ } SECUREC_WHILE_ZERO /* The count variable can be reduced to 0, and the external function complements the \0 terminator. */ #define SECUREC_IS_REST_BUF_ENOUGH(stream, needLen) ((int)((stream)->count - \ (int)(needLen) * (int)(sizeof(SecChar))) >= 0) #define SECUREC_FMT_STATE_OFFSET 256 #ifdef SECUREC_FOR_WCHAR #define SECUREC_FMT_TYPE(c, fmtTable) ((((unsigned int)(int)(c)) <= (unsigned int)(int)SECUREC_CHAR('~')) ? 
\ ((fmtTable)[(unsigned char)(c)]) : 0) #define SECUREC_DECODE_STATE(c, fmtTable, lastState) (SecFmtState)((((fmtTable)[(SECUREC_FMT_TYPE(c, (fmtTable))) * \ ((unsigned char)STAT_INVALID + 1) + \ (unsigned char)(lastState) + \ SECUREC_FMT_STATE_OFFSET]))) #else #define SECUREC_DECODE_STATE(c, fmtTable, lastState) (SecFmtState)(((fmtTable)[((fmtTable)[(unsigned char)(c)]) * \ ((unsigned char)STAT_INVALID + 1) + \ (unsigned char)(lastState) + \ SECUREC_FMT_STATE_OFFSET])) #endif static void SecDecodeFlags(SecChar ch, SecFormatAttr *attr) { switch (ch) { case SECUREC_CHAR(' '): attr->flags |= SECUREC_FLAG_SIGN_SPACE; break; case SECUREC_CHAR('+'): attr->flags |= SECUREC_FLAG_SIGN; break; case SECUREC_CHAR('-'): attr->flags |= SECUREC_FLAG_LEFT; break; case SECUREC_CHAR('0'): attr->flags |= SECUREC_FLAG_LEADZERO; /* add zero th the front */ break; case SECUREC_CHAR('#'): attr->flags |= SECUREC_FLAG_ALTERNATE; /* output %x with 0x */ break; default: break; } return; } /* * Decoded size identifier in format string to Reduce the number of lines of function code */ static int SecDecodeSizeI(SecFormatAttr *attr, const SecChar **format) { #ifdef SECUREC_ON_64BITS attr->flags |= SECUREC_FLAG_I64; /* %I to INT64 */ #endif if ((**format == SECUREC_CHAR('6')) && (*((*format) + 1) == SECUREC_CHAR('4'))) { (*format) += 2; /* add 2 to skip I64 */ attr->flags |= SECUREC_FLAG_I64; /* %I64 to INT64 */ } else if ((**format == SECUREC_CHAR('3')) && (*((*format) + 1) == SECUREC_CHAR('2'))) { (*format) += 2; /* add 2 to skip I32 */ attr->flags &= ~SECUREC_FLAG_I64; /* %I64 to INT32 */ } else if ((**format == SECUREC_CHAR('d')) || (**format == SECUREC_CHAR('i')) || (**format == SECUREC_CHAR('o')) || (**format == SECUREC_CHAR('u')) || (**format == SECUREC_CHAR('x')) || (**format == SECUREC_CHAR('X'))) { /* do nothing */ } else { /* Compatibility code for "%I" just print I */ return -1; } return 0; } /* * Decoded size identifier in format string */ static int SecDecodeSize(SecChar ch, SecFormatAttr *attr, const SecChar **format) { switch (ch) { #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT case SECUREC_CHAR('j'): attr->flags |= SECUREC_FLAG_INTMAX; break; #endif case SECUREC_CHAR('q'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('L'): attr->flags |= SECUREC_FLAG_LONGLONG | SECUREC_FLAG_LONG_DOUBLE; break; case SECUREC_CHAR('l'): if (**format == SECUREC_CHAR('l')) { *format = *format + 1; attr->flags |= SECUREC_FLAG_LONGLONG; /* long long */ } else { attr->flags |= SECUREC_FLAG_LONG; /* long int or wchar_t */ } break; case SECUREC_CHAR('t'): attr->flags |= SECUREC_FLAG_PTRDIFF; break; #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT case SECUREC_CHAR('z'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('Z'): attr->flags |= SECUREC_FLAG_SIZE; break; #endif case SECUREC_CHAR('I'): if (SecDecodeSizeI(attr, format) != 0) { /* Compatibility code for "%I" just print I */ return -1; } break; case SECUREC_CHAR('h'): if (**format == SECUREC_CHAR('h')) { attr->flags |= SECUREC_FLAG_CHAR; /* char */ } else { attr->flags |= SECUREC_FLAG_SHORT; /* short int */ } break; case SECUREC_CHAR('w'): attr->flags |= SECUREC_FLAG_WIDECHAR; /* wide char */ break; default: break; } return 0; } /* * Decoded char type identifier */ static int SecDecodeTypeC(SecFormatAttr *attr, unsigned int cValue, SecFormatBuf *formatBuf, SecBuffer *buffer) { #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT)) && !(defined(__hpux)) && !(defined(SECUREC_ON_SOLARIS)) attr->flags &= ~SECUREC_FLAG_LEADZERO; #endif #ifdef SECUREC_FOR_WCHAR attr->bufferIsWide = 1; if 
(attr->flags & SECUREC_FLAG_SHORT) { #if SECUREC_HAVE_MBTOWC /* multibyte character to wide character */ char tmpChar[2]; /* One character string, length is 2 */ tmpChar[0] = (char)(cValue & 0x00ff); tmpChar[1] = '\0'; if (mbtowc(buffer->wStr, tmpChar, sizeof(tmpChar)) < 0) { return -1; } #else return -1; #endif } else { buffer->wStr[0] = (wchar_t)cValue; } formatBuf->wStr = buffer->wStr; return 1; /* only 1 wide character */ #else /* SECUREC_FOR_WCHAR */ attr->bufferIsWide = 0; if (attr->flags & (SECUREC_FLAG_LONG | SECUREC_FLAG_WIDECHAR)) { #if SECUREC_HAVE_WCTOMB wchar_t wChar = (wchar_t)cValue; int textLen; /* wide character to multibyte character */ SECUREC_MASK_MSVC_CRT_WARNING textLen = wctomb(buffer->str, wChar); SECUREC_END_MASK_MSVC_CRT_WARNING if (textLen < 0) { return -1; } formatBuf->str = buffer->str; return textLen; #else return -1; #endif } else { /* get multibyte character from argument */ unsigned short temp; temp = (unsigned short)cValue; buffer->str[0] = (char)temp; formatBuf->str = buffer->str; return 1; /* only 1 character */ } #endif } /* literal string to print null ptr, define it as array rather than const text area * is to avoid gcc warning with pointing const text with variable */ #if SECUREC_HAVE_WCHART static wchar_t g_wStrNullString[SECUREC_NULL_STRING_SIZE] = { L'(', L'n', L'u', L'l', L'l', L')', L'\0', L'\0' }; #endif static char g_strNullString[SECUREC_NULL_STRING_SIZE] = "(null)"; static int SecDecodeTypeSchar(const SecFormatAttr *attr, SecFormatBuf *formatBuf) { int finalPrecision = (attr->precision == -1) ? SECUREC_INT_MAX : attr->precision; int textLen; if (formatBuf->str == NULL) { /* NULL passed, use special string */ formatBuf->str = g_strNullString; } if (finalPrecision == SECUREC_INT_MAX) { /* precision NOT assigned */ /* The strlen performance is high when the string length is greater than 32 */ textLen = (int)strlen(formatBuf->str); } else { /* precision assigned */ size_t tmpLen; SECUREC_CALC_STR_LEN(formatBuf->str, (size_t)(unsigned int)finalPrecision, &tmpLen); textLen = (int)tmpLen; } return textLen; } #if SECUREC_HAVE_WCHART static int SecDecodeTypeSwchar(SecFormatAttr *attr, SecFormatBuf *formatBuf) { int finalPrecision = (attr->precision == -1) ? 
SECUREC_INT_MAX : attr->precision; int textLen; attr->bufferIsWide = 1; if (formatBuf->wStr == NULL) { /* NULL passed, use special string */ formatBuf->wStr = g_wStrNullString; } /* textLen in wchar_t */ SECUREC_CALC_WSTR_LEN(formatBuf->wStr, finalPrecision, &textLen); return textLen; } #endif /* * Decoded string identifier */ static int SecDecodeTypeS(SecFormatAttr *attr, char *argPtr, SecFormatBuf *formatBuf) { int textLen; #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT)) && (!defined(SECUREC_ON_UNIX)) attr->flags &= ~SECUREC_FLAG_LEADZERO; #endif formatBuf->str = argPtr; #ifdef SECUREC_FOR_WCHAR #if defined(SECUREC_COMPATIBLE_LINUX_FORMAT) if (!(attr->flags & SECUREC_FLAG_LONG)) { attr->flags |= SECUREC_FLAG_SHORT; } #endif if (attr->flags & SECUREC_FLAG_SHORT) { /* textLen now contains length in multibyte chars */ textLen = SecDecodeTypeSchar(attr, formatBuf); } else { /* textLen now contains length in wide chars */ textLen = SecDecodeTypeSwchar(attr, formatBuf); } #else /* SECUREC_FOR_WCHAR */ if (attr->flags & (SECUREC_FLAG_LONG | SECUREC_FLAG_WIDECHAR)) { /* textLen now contains length in wide chars */ #if SECUREC_HAVE_WCHART textLen = SecDecodeTypeSwchar(attr, formatBuf); #else textLen = 0; #endif } else { /* textLen now contains length in multibyte chars */ textLen = SecDecodeTypeSchar(attr, formatBuf); } #endif /* SECUREC_FOR_WCHAR */ return textLen; } /* * Write one character to dest buffer */ static void SecOutputOneChar(SecChar ch, SecPrintfStream *stream, int *counter) { /* normal state, write character */ if (SECUREC_IS_REST_BUF_ENOUGH(stream, 1)) { /* only one char */ SECUREC_SAFE_WRITE_CHAR(ch, stream, counter); /* char * cast to wchar * */ } else { #ifdef SECUREC_FOR_WCHAR SecWriteCharW(ch, stream, counter); #else /* optimize function call to code */ *counter = -1; stream->count = -1; #endif } } /* * Check precison in format */ static int SecDecodePrecision(SecChar ch, SecFormatAttr *formatAttr) { if (formatAttr->dynPrecision == 0) { /* add digit to current precision */ if (SECUREC_MUL_TEN_ADD_BEYOND_MAX(formatAttr->precision)) { return -1; } formatAttr->precision = (int)SECUREC_MUL_TEN((unsigned int)formatAttr->precision) + (unsigned char)(ch - SECUREC_CHAR('0')); } else { if (formatAttr->precision < 0) { formatAttr->precision = -1; } if (formatAttr->precision > SECUREC_MAX_WIDTH_LEN) { return -1; } } return 0; } /* * Check width in format */ static int SecDecodeWidth(SecChar ch, SecFormatAttr *formatAttr, SecFmtState lastState) { if (formatAttr->dynWidth == 0) { if (lastState != STAT_WIDTH) { formatAttr->fldWidth = 0; } if (SECUREC_MUL_TEN_ADD_BEYOND_MAX(formatAttr->fldWidth)) { return -1; } formatAttr->fldWidth = (int)SECUREC_MUL_TEN((unsigned int)formatAttr->fldWidth) + (unsigned char)(ch - SECUREC_CHAR('0')); } else { if (formatAttr->fldWidth < 0) { formatAttr->flags |= SECUREC_FLAG_LEFT; formatAttr->fldWidth = (-formatAttr->fldWidth); if (formatAttr->fldWidth > SECUREC_MAX_WIDTH_LEN) { return -1; } } } return 0; } #ifdef SECUREC_FOR_WCHAR /* * Formatting output core functions for wchar version.Called by a function such as vswprintf_s * argList must not be declare as const */ static int SecOutputSW(SecPrintfStream *stream, const wchar_t *cFormat, va_list argList) #else /* * Formatting output core functions for char version.Called by a function such as vsnprintf_s */ static int SecOutputS(SecPrintfStream *stream, const char *cFormat, va_list argList) #endif { const SecChar *format = cFormat; #if SECUREC_ENABLE_SPRINTF_FLOAT char *floatBuf = NULL; #endif SecFormatBuf 
formatBuf; static const char *itoaUpperDigits = "0123456789ABCDEFX"; static const char *itoaLowerDigits = "0123456789abcdefx"; const char *digits = itoaUpperDigits; unsigned int radix = SECUREC_RADIX_DECIMAL; int charsOut; /* characters written */ int prefixLen = 0; /* Must be initialized or compiler alerts */ int padding = 0; int textLen; /* length of the text */ int noOutput = 0; /* Must be initialized or compiler alerts */ SecFmtState state; SecFmtState lastState; SecChar prefix[SECUREC_PREFIX_LEN] = { 0 }; SecChar ch; /* currently read character */ static const unsigned char stateTable[SECUREC_STATE_TABLE_SIZE] = { /* type 0: nospecial meanin; * 1: '%'; * 2: '.' * 3: '*' * 4: '0' * 5: '1' ... '9' * 6: ' ', '+', '-', '#' * 7: 'h', 'l', 'L', 'F', 'w' , 'N','z','q','t','j' * 8: 'd','o','u','i','x','X','e','f','g' */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x03, 0x06, 0x00, 0x06, 0x02, 0x00, 0x04, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x08, 0x08, 0x00, 0x07, 0x00, 0x00, 0x07, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x08, 0x08, 0x08, 0x08, 0x07, 0x08, 0x07, 0x00, 0x07, 0x00, 0x00, 0x08, 0x08, 0x07, 0x00, 0x08, 0x07, 0x08, 0x00, 0x07, 0x08, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, /* fill zero for normal char 128 byte for 0x80 - 0xff */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, /* state 0: normal * 1: percent * 2: flag * 3: width * 4: dot * 5: precis * 6: size * 7: type * 8: invalid */ 0x00, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x00, 0x08, 0x08, 0x08, 0x08, 0x08, 0x01, 0x00, 0x00, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x00, 0x00, 0x00, 0x03, 0x03, 0x08, 0x05, 0x08, 0x08, 0x00, 0x00, 0x00, 0x02, 0x02, 0x03, 0x05, 0x05, 0x08, 0x00, 0x00, 0x00, 0x03, 0x03, 0x03, 0x05, 0x05, 0x08, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x08, 0x08, 0x08, 0x00, 0x00, 0x00, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x00, 0x00, 0x00, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x00, 0x00 }; SecFormatAttr formatAttr; SecBuffer buffer; formatAttr.flags = 0; formatAttr.bufferIsWide = 0; /* flag for buffer contains wide chars */ formatAttr.fldWidth = 0; formatAttr.precision = 0; formatAttr.dynWidth = 0; formatAttr.dynPrecision = 0; charsOut = 0; textLen = 0; state = STAT_NORMAL; /* starting state */ formatBuf.str = NULL; /* loop each format character */ /* remove format != NULL */ while ((ch = *format) != SECUREC_CHAR('\0') && charsOut >= 0) { ++format; lastState = state; state = SECUREC_DECODE_STATE(ch, stateTable, 
lastState); switch (state) { case STAT_NORMAL: SecOutputOneChar(ch, stream, &charsOut); continue; case STAT_PERCENT: /* set default values */ prefixLen = 0; noOutput = 0; formatAttr.flags = 0; formatAttr.fldWidth = 0; formatAttr.precision = -1; formatAttr.bufferIsWide = 0; formatAttr.dynWidth = 0; formatAttr.dynPrecision = 0; break; case STAT_FLAG: /* set flag based on which flag character */ SecDecodeFlags(ch, &formatAttr); break; case STAT_WIDTH: /* update width value */ if (ch == SECUREC_CHAR('*')) { /* get width */ formatAttr.fldWidth = (int)va_arg(argList, int); formatAttr.dynWidth = 1; } else { formatAttr.dynWidth = 0; } if (SecDecodeWidth(ch, &formatAttr, lastState) != 0) { return -1; } break; case STAT_DOT: formatAttr.precision = 0; break; case STAT_PRECIS: /* update precison value */ if (ch == SECUREC_CHAR('*')) { /* get precision from arg list */ formatAttr.precision = (int)va_arg(argList, int); formatAttr.dynPrecision = 1; } else { formatAttr.dynPrecision = 0; } if (SecDecodePrecision(ch, &formatAttr) != 0) { return -1; } break; case STAT_SIZE: /* read a size specifier, set the formatAttr.flags based on it */ if (SecDecodeSize(ch, &formatAttr, &format) != 0) { /* Compatibility code for "%I" just print I */ SecOutputOneChar(ch, stream, &charsOut); state = STAT_NORMAL; continue; } break; case STAT_TYPE: switch (ch) { case SECUREC_CHAR('C'): /* wide char */ if (!(formatAttr.flags & (SECUREC_FLAG_SHORT | SECUREC_FLAG_LONG | SECUREC_FLAG_WIDECHAR))) { #ifdef SECUREC_FOR_WCHAR formatAttr.flags |= SECUREC_FLAG_SHORT; #else formatAttr.flags |= SECUREC_FLAG_WIDECHAR; #endif } /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('c'): do { unsigned int cValue = (unsigned int)va_arg(argList, int); textLen = SecDecodeTypeC(&formatAttr, cValue, &formatBuf, &buffer); if (textLen < 0) { noOutput = 1; } } SECUREC_WHILE_ZERO; break; case SECUREC_CHAR('S'): /* wide char string */ if (!(formatAttr.flags & (SECUREC_FLAG_SHORT | SECUREC_FLAG_LONG | SECUREC_FLAG_WIDECHAR))) { #ifndef SECUREC_FOR_WCHAR formatAttr.flags |= SECUREC_FLAG_WIDECHAR; #else formatAttr.flags |= SECUREC_FLAG_SHORT; #endif } /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('s'): do { char *argPtr = (char *)va_arg(argList, char *); textLen = SecDecodeTypeS(&formatAttr, argPtr, &formatBuf); } SECUREC_WHILE_ZERO; break; case SECUREC_CHAR('n'): /* higher risk disable it */ return -1; case SECUREC_CHAR('E'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('F'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('G'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('A'): /* fall-through */ /* FALLTHRU */ /* convert format char to lower , use Explicit conversion to clean up compilation warning */ ch = (SecChar)(ch + ((SecChar)(SECUREC_CHAR('a')) - (SECUREC_CHAR('A')))); /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('e'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('f'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('g'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('a'): #if SECUREC_ENABLE_SPRINTF_FLOAT do { int bufferSize = 0; /* size of formatBuf.str */ /* floating point conversion */ formatBuf.str = buffer.str; /* output buffer for float string with default size */ /* compute the precision value */ if (formatAttr.precision < 0) { formatAttr.precision = SECUREC_FLOAT_DEFAULT_PRECISION; } else if (formatAttr.precision == 0 && ch == SECUREC_CHAR('g')) { formatAttr.precision = 1; } /* calc buffer size to store double value * The maximum length of SECUREC_MAX_WIDTH_LEN is enough */ if 
(formatAttr.flags & SECUREC_FLAG_LONG_DOUBLE) { if (formatAttr.precision > (SECUREC_MAX_WIDTH_LEN - SECUREC_FLOAT_BUFSIZE_LB)) { noOutput = 1; break; } /* Long double needs to meet the basic print length */ bufferSize = SECUREC_FLOAT_BUFSIZE_LB + formatAttr.precision + SECUREC_FLOAT_BUF_EXT; } else { if (formatAttr.precision > (SECUREC_MAX_WIDTH_LEN - SECUREC_FLOAT_BUFSIZE)) { noOutput = 1; break; } /* Double needs to meet the basic print length */ bufferSize = SECUREC_FLOAT_BUFSIZE + formatAttr.precision + SECUREC_FLOAT_BUF_EXT; } if (formatAttr.fldWidth > bufferSize) { bufferSize = formatAttr.fldWidth + SECUREC_FLOAT_BUF_EXT; } if (bufferSize > SECUREC_BUFFER_SIZE) { /* the current vlaue of SECUREC_BUFFER_SIZE could NOT store the * formatted float string */ floatBuf = (char *)SECUREC_MALLOC(((size_t)(unsigned int)bufferSize)); if (floatBuf != NULL) { formatBuf.str = floatBuf; } else { noOutput = 1; break; } } do { /* add following code to call system sprintf API for float number */ const SecChar *pFloatFmt = format - 2; /* sub 2 to the position before 'f' or 'g' */ int k; int fFmtStrLen; char fFmtBuf[SECUREC_FMT_STR_LEN]; char *fFmtStr = fFmtBuf; char *fFmtHeap = NULL; /* to clear warning */ while (SECUREC_CHAR('%') != *pFloatFmt) { /* must meet '%' */ --pFloatFmt; } fFmtStrLen = (int)((format - pFloatFmt) + 1); /* with ending terminator */ if (fFmtStrLen > SECUREC_FMT_STR_LEN) { /* if SECUREC_FMT_STR_LEN is NOT enough, alloc a new buffer */ fFmtHeap = (char *)SECUREC_MALLOC((size_t)((unsigned int)fFmtStrLen)); if (fFmtHeap == NULL) { noOutput = 1; break; } else { for (k = 0; k < fFmtStrLen - 1; ++k) { /* convert wchar to char */ fFmtHeap[k] = (char)(pFloatFmt[k]); /* copy the format string */ } fFmtHeap[k] = '\0'; fFmtStr = fFmtHeap; } } else { /* purpose of the repeat code is to solve the tool alarm Redundant_Null_Check */ for (k = 0; k < fFmtStrLen - 1; ++k) { /* convert wchar to char */ fFmtBuf[k] = (char)(pFloatFmt[k]); /* copy the format string */ } fFmtBuf[k] = '\0'; } if (formatAttr.flags & SECUREC_FLAG_LONG_DOUBLE) { #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT long double tmp = (long double)va_arg(argList, long double); textLen = SecFormatLongDboule(formatBuf.str, &formatAttr, fFmtStr, tmp); #else double tmp = (double)va_arg(argList, double); textLen = SecFormatDboule(formatBuf.str, &formatAttr, fFmtStr, tmp); #endif } else { double tmp = (double)va_arg(argList, double); textLen = SecFormatDboule(formatBuf.str, &formatAttr, fFmtStr, tmp); } if (fFmtHeap != NULL) { /* if buffer is alloced on heap, free it */ SECUREC_FREE(fFmtHeap); fFmtHeap = NULL; /* to clear e438 last value assigned not used , the compiler will * optimize this code */ (void)fFmtHeap; } if (textLen < 0 || textLen >= bufferSize) { /* bufferSize is large enough, just validation the return value */ noOutput = 1; break; } /* no padding ,this variable to calculate amount of padding */ formatAttr.fldWidth = textLen; prefixLen = 0; /* no padding ,this variable to calculate amount of padding */ formatAttr.flags = 0; /* clear all internal formatAttr.flags */ break; } SECUREC_WHILE_ZERO; } SECUREC_WHILE_ZERO; break; #else return -1; #endif case SECUREC_CHAR('p'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('X'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('x'): /* unsigned lower hex output */ digits = itoaLowerDigits; radix = SECUREC_RADIX_HEX; switch (ch) { case SECUREC_CHAR('p'): /* print a pointer */ #if defined(SECUREC_COMPATIBLE_WIN_FORMAT) formatAttr.flags &= ~SECUREC_FLAG_LEADZERO; #else formatAttr.flags 
|= SECUREC_FLAG_POINTER; #endif #ifdef SECUREC_ON_64BITS formatAttr.flags |= SECUREC_FLAG_I64; /* converting an int64 */ #else formatAttr.flags |= SECUREC_FLAG_LONG; /* converting a long */ #endif #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) || defined(SECUREC_VXWORKS_PLATFORM)) && (!defined(SECUREC_ON_UNIX)) #if defined(SECUREC_VXWORKS_PLATFORM) formatAttr.precision = 1; #else formatAttr.precision = 0; #endif formatAttr.flags |= SECUREC_FLAG_ALTERNATE; /* "0x" is not default prefix in UNIX */ break; #else /* not linux vxwoks */ #if defined(_AIX) || defined(SECUREC_ON_SOLARIS) formatAttr.precision = 1; #else formatAttr.precision = 2 * sizeof(void *); /* 2 precision of different systems */ #endif #endif #if defined(SECUREC_ON_UNIX) break; #endif /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('X'): /* fall-through */ /* FALLTHRU */ /* unsigned upper hex output */ digits = itoaUpperDigits; break; default: break; } if (formatAttr.flags & SECUREC_FLAG_ALTERNATE) { /* alternate form means '0x' prefix */ prefix[0] = SECUREC_CHAR('0'); prefix[1] = (SecChar)(digits[16]); /* 16 for 'x' or 'X' */ #if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) || defined(SECUREC_VXWORKS_PLATFORM)) if (ch == 'p') { prefix[1] = SECUREC_CHAR('x'); } #endif #if defined(_AIX) || defined(SECUREC_ON_SOLARIS) if (ch == 'p') { prefixLen = 0; } else { prefixLen = SECUREC_PREFIX_LEN; } #else prefixLen = SECUREC_PREFIX_LEN; #endif } /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('i'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('d'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('u'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('o'): /* fall-through */ /* FALLTHRU */ switch (ch) { case SECUREC_CHAR('i'): /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('d'): /* fall-through */ /* FALLTHRU */ /* signed decimal output */ formatAttr.flags |= SECUREC_FLAG_SIGNED; /* fall-through */ /* FALLTHRU */ case SECUREC_CHAR('u'): radix = SECUREC_RADIX_DECIMAL; break; case SECUREC_CHAR('o'): /* unsigned octal output */ radix = SECUREC_RADIX_OCTAL; if (formatAttr.flags & SECUREC_FLAG_ALTERNATE) { /* alternate form means force a leading 0 */ formatAttr.flags |= SECUREC_FLAG_FORCE_OCTAL; } break; default: break; } do { SecUnsignedInt64 number = 0; /* number to convert */ SecInt64 l; /* temp long value */ /* read argument into variable l */ if (formatAttr.flags & SECUREC_FLAG_I64) { l = (SecInt64)va_arg(argList, SecInt64); } else if (formatAttr.flags & SECUREC_FLAG_LONGLONG) { l = (SecInt64)va_arg(argList, SecInt64); } else #ifdef SECUREC_ON_64BITS if (formatAttr.flags & SECUREC_FLAG_LONG) { l = (long)va_arg(argList, long); } else #endif /* SECUREC_ON_64BITS */ if (formatAttr.flags & SECUREC_FLAG_CHAR) { if (formatAttr.flags & SECUREC_FLAG_SIGNED) { l = (char)va_arg(argList, int); /* sign extend */ if (l >= 128) { /* 128 on some platform, char is always unsigned */ SecUnsignedInt64 tmpL = (SecUnsignedInt64)l; unsigned char tmpCh = (unsigned char)(~(tmpL)); l = tmpCh + 1; formatAttr.flags |= SECUREC_FLAG_NEGATIVE; } } else { l = (unsigned char)va_arg(argList, int); /* zero-extend */ } } else if (formatAttr.flags & SECUREC_FLAG_SHORT) { if (formatAttr.flags & SECUREC_FLAG_SIGNED) { l = (short)va_arg(argList, int); /* sign extend */ } else { l = (unsigned short)va_arg(argList, int); /* zero-extend */ } } #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT else if (formatAttr.flags & SECUREC_FLAG_PTRDIFF) { l = (ptrdiff_t)va_arg(argList, ptrdiff_t); /* sign extend */ } else if (formatAttr.flags & SECUREC_FLAG_SIZE) { if 
(formatAttr.flags & SECUREC_FLAG_SIGNED) { /* No suitable macros were found to handle the branch */ if (SecIsSameSize(sizeof(size_t), sizeof(long))) { l = va_arg(argList, long); /* sign extend */ } else if (SecIsSameSize(sizeof(size_t), sizeof(long long))) { l = va_arg(argList, long long); /* sign extend */ } else { l = va_arg(argList, int); /* sign extend */ } } else { l = (SecInt64)(size_t)va_arg(argList, size_t); /* sign extend */ } } else if (formatAttr.flags & SECUREC_FLAG_INTMAX) { if (formatAttr.flags & SECUREC_FLAG_SIGNED) { l = va_arg(argList, SecInt64); /* sign extend */ } else { /* sign extend */ l = (SecInt64)(SecUnsignedInt64)va_arg(argList, SecUnsignedInt64); } } #endif else { if (formatAttr.flags & SECUREC_FLAG_SIGNED) { l = va_arg(argList, int); /* sign extend */ } else { l = (unsigned int)va_arg(argList, int); /* zero-extend */ } } /* check for negative; copy into number */ if ((formatAttr.flags & SECUREC_FLAG_SIGNED) && l < 0) { number = (SecUnsignedInt64)(-l); formatAttr.flags |= SECUREC_FLAG_NEGATIVE; } else { number = (SecUnsignedInt64)l; } if (((formatAttr.flags & SECUREC_FLAG_I64) == 0) && #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT ((formatAttr.flags & SECUREC_FLAG_INTMAX) == 0) && #endif #ifdef SECUREC_ON_64BITS ((formatAttr.flags & SECUREC_FLAG_PTRDIFF) == 0) && ((formatAttr.flags & SECUREC_FLAG_SIZE) == 0) && #if !defined(SECUREC_COMPATIBLE_WIN_FORMAT) /* on window 64 system sizeof long is 32bit */ ((formatAttr.flags & SECUREC_FLAG_LONG) == 0) && #endif #endif ((formatAttr.flags & SECUREC_FLAG_LONGLONG) == 0)) { number &= 0xffffffff; /* use 0xffffffff as 32 bit mask */ } /* check precision value for default */ if (formatAttr.precision < 0) { formatAttr.precision = 1; /* default precision */ } else { #if defined(SECUREC_COMPATIBLE_WIN_FORMAT) formatAttr.flags &= ~SECUREC_FLAG_LEADZERO; #else if (!(formatAttr.flags & SECUREC_FLAG_POINTER)) { formatAttr.flags &= ~SECUREC_FLAG_LEADZERO; } #endif if (formatAttr.precision > SECUREC_MAX_PRECISION) { formatAttr.precision = SECUREC_MAX_PRECISION; } } /* Check if data is 0; if so, turn off hex prefix, * 'p' add 0x prefix, otherwise not add prefix */ if (number == 0) { #if !(defined(SECUREC_VXWORKS_PLATFORM) || defined(__hpux)) prefixLen = 0; #else if ((ch == 'p') && (formatAttr.flags & SECUREC_FLAG_ALTERNATE)) { prefixLen = SECUREC_PREFIX_LEN; } else { prefixLen = 0; } #endif } /* Convert data to ASCII */ formatBuf.str = &buffer.str[SECUREC_BUFFER_SIZE]; if (number > 0) { #ifdef SECUREC_ON_64BITS switch (radix) { /* the compiler will optimize each one */ case SECUREC_RADIX_DECIMAL: SECUREC_SPECIAL_QWORD_BASE10(number); break; case SECUREC_RADIX_HEX: SECUREC_SPECIAL_QWORD(number, SECUREC_RADIX_HEX); break; case SECUREC_RADIX_OCTAL: SECUREC_SPECIAL_QWORD(number, SECUREC_RADIX_OCTAL); break; default: break; } #else /* for 32 bits system */ if (number <= 0xFFFFFFFFUL) { /* in most case, the value to be converted is small value */ SecUnsignedInt32 n32Tmp = (SecUnsignedInt32)number; switch (radix) { case SECUREC_RADIX_HEX: SECUREC_SPECIAL_DWORD(n32Tmp, SECUREC_RADIX_HEX); break; case SECUREC_RADIX_OCTAL: SECUREC_SPECIAL_DWORD(n32Tmp, SECUREC_RADIX_OCTAL); break; #ifdef _AIX /* the compiler will optimize div 10 */ case SECUREC_RADIX_DECIMAL: SECUREC_SPECIAL_DWORD(n32Tmp, SECUREC_RADIX_DECIMAL); break; #else case SECUREC_RADIX_DECIMAL: do { /* fast div 10 */ SecUnsignedInt32 q; SecUnsignedInt32 r; do { *--formatBuf.str = digits[n32Tmp % SECUREC_RADIX_DECIMAL]; q = (n32Tmp >> 1) + (n32Tmp >> 2); /* fast div magic 2 */ q = q + (q >> 4); 
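/* Explanatory aside (added; not in the original source): the shift-add sequence here approximates multiplication by 4/5 without a divide instruction: q = (n32Tmp >> 1) + (n32Tmp >> 2) is 0.75 * n32Tmp, and each q += q >> (4, 8, 16) step extends the repeating binary fraction of 4/5, so q converges to about 0.8 * n32Tmp; the final q >> 3 then yields roughly n32Tmp / 10 (possibly one too small), r = n32Tmp - 10 * q recovers the remainder, and (r > 9) ? (q + 1) : q corrects the quotient used in the next loop pass. */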
/* fast div magic 4 */ q = q + (q >> 8); /* fast div magic 8 */ q = q + (q >> 16); /* fast div magic 16 */ q = q >> 3; /* fast div magic 3 */ r = n32Tmp - SECUREC_MUL_TEN(q); n32Tmp = (r > 9) ? (q + 1) : q; /* fast div magic 9 */ } while (n32Tmp != 0); } SECUREC_WHILE_ZERO; break; #endif default: break; } /* end switch */ } else { /* the value to be converted is greater than 4G */ #if defined(SECUREC_VXWORKS_VERSION_5_4) do { SecUnsignedInt32 digit = 0; /* ascii value of digit */ SecUnsignedInt64 quotient = 0; if (SecU64Div32(number, (SecUnsignedInt32)radix, &quotient, &digit) != 0) { noOutput = 1; break; } *--formatBuf.str = digits[digit]; number = quotient; } while (number != 0); #else switch (radix) { /* the compiler will optimize div 10 */ case SECUREC_RADIX_DECIMAL: SECUREC_SPECIAL_QWORD_BASE10(number); break; case SECUREC_RADIX_OCTAL: SECUREC_SPECIAL_QWORD(number, SECUREC_RADIX_OCTAL); break; case SECUREC_RADIX_HEX: SECUREC_SPECIAL_QWORD(number, SECUREC_RADIX_HEX); break; default: break; } #endif } #endif } /* compute length of number; if textLen > 0, then formatBuf.str must be in buffer.str */ textLen = (int)(size_t)((char *)&buffer.str[SECUREC_BUFFER_SIZE] - formatBuf.str); if (formatAttr.precision > textLen) { int ii; for (ii = 0; ii < formatAttr.precision - textLen; ++ii) { *--formatBuf.str = '0'; } textLen = formatAttr.precision; } /* Force a leading zero if FORCEOCTAL flag set */ if ((formatAttr.flags & SECUREC_FLAG_FORCE_OCTAL) && (textLen == 0 || formatBuf.str[0] != '0')) { *--formatBuf.str = '0'; ++textLen; /* add a zero */ } } SECUREC_WHILE_ZERO; break; default: break; } while (noOutput < 1) { if (formatAttr.flags & SECUREC_FLAG_SIGNED) { if (formatAttr.flags & SECUREC_FLAG_NEGATIVE) { /* prefix is a '-' */ prefix[0] = SECUREC_CHAR('-'); prefixLen = 1; } else if (formatAttr.flags & SECUREC_FLAG_SIGN) { /* prefix is '+' */ prefix[0] = SECUREC_CHAR('+'); prefixLen = 1; } else if (formatAttr.flags & SECUREC_FLAG_SIGN_SPACE) { /* prefix is ' ' */ prefix[0] = SECUREC_CHAR(' '); prefixLen = 1; } } #if defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && (!defined(SECUREC_ON_UNIX)) if ((formatAttr.flags & SECUREC_FLAG_POINTER) && (textLen == 0)) { formatAttr.flags &= ~SECUREC_FLAG_LEADZERO; formatBuf.str = &buffer.str[SECUREC_BUFFER_SIZE - 1]; *formatBuf.str-- = '\0'; *formatBuf.str-- = ')'; *formatBuf.str-- = 'l'; *formatBuf.str-- = 'i'; *formatBuf.str-- = 'n'; *formatBuf.str = '('; textLen = 5; /* length of (nil) is 5 */ } #endif /* calculate amount of padding */ padding = (formatAttr.fldWidth - textLen) - prefixLen; /* put out the padding, prefix, and text, in the correct order */ if (!(formatAttr.flags & (SECUREC_FLAG_LEFT | SECUREC_FLAG_LEADZERO)) && padding > 0) { /* pad on left with blanks */ if (SECUREC_IS_REST_BUF_ENOUGH(stream, padding)) { /* char * cast to wchar * */ SECUREC_SAFE_PADDING(SECUREC_CHAR(' '), padding, stream, &charsOut); } else { SECUREC_WRITE_MULTI_CHAR(SECUREC_CHAR(' '), padding, stream, &charsOut); } } /* write prefix */ if (prefixLen > 0) { SecChar *pPrefix = prefix; if (SECUREC_IS_REST_BUF_ENOUGH(stream, prefixLen)) { /* max prefix len is 2, use loop copy */ /* char * cast to wchar * in WCHAR version */ SECUREC_SAFE_WRITE_STR_OPT(pPrefix, prefixLen, stream, &charsOut); } else { SECUREC_WRITE_STRING(prefix, prefixLen, stream, &charsOut); } } if ((formatAttr.flags & SECUREC_FLAG_LEADZERO) && !(formatAttr.flags & SECUREC_FLAG_LEFT) && padding > 0) { /* write leading zeros */ if (SECUREC_IS_REST_BUF_ENOUGH(stream, padding)) { /* char * cast to wchar * */
SECUREC_SAFE_PADDING(SECUREC_CHAR('0'), padding, stream, &charsOut); } else { SECUREC_WRITE_MULTI_CHAR(SECUREC_CHAR('0'), padding, stream, &charsOut); } } /* write text */ #ifndef SECUREC_FOR_WCHAR if (formatAttr.bufferIsWide != 0 && (textLen > 0)) { #if SECUREC_HAVE_WCTOMB wchar_t *p = formatBuf.wStr; int count = textLen; while (count > 0) { char tmpBuf[SECUREC_MB_LEN + 1]; SECUREC_MASK_MSVC_CRT_WARNING int retVal = wctomb(tmpBuf, *p); SECUREC_END_MASK_MSVC_CRT_WARNING if (retVal <= 0) { charsOut = -1; break; } SECUREC_WRITE_STRING(tmpBuf, retVal, stream, &charsOut); --count; ++p; } #else charsOut = -1; break; #endif } else { if (SECUREC_IS_REST_BUF_ENOUGH(stream, textLen)) { SECUREC_SAFE_WRITE_STR(formatBuf.str, textLen, stream, &charsOut); } else { SECUREC_WRITE_STRING(formatBuf.str, textLen, stream, &charsOut); } } #else /* SECUREC_FOR_WCHAR */ if (formatAttr.bufferIsWide == 0 && textLen > 0) { #if SECUREC_HAVE_MBTOWC int count = textLen; char *p = formatBuf.str; while (count > 0) { wchar_t wChar = L'\0'; int retVal = mbtowc(&wChar, p, (size_t)MB_CUR_MAX); if (retVal <= 0) { charsOut = -1; break; } SecWriteCharW(wChar, stream, &charsOut); p += retVal; count -= retVal; } #else charsOut = -1; break; #endif } else { if (SECUREC_IS_REST_BUF_ENOUGH(stream, textLen)) { /* char * cast to wchar * */ SECUREC_SAFE_WRITE_STR(formatBuf.wStr, textLen, stream, &charsOut); } else { SECUREC_WRITE_STRING(formatBuf.wStr, textLen, stream, &charsOut); } } #endif /* SECUREC_FOR_WCHAR */ if (charsOut >= 0 && (formatAttr.flags & SECUREC_FLAG_LEFT) && padding > 0) { /* pad on right with blanks */ if (SECUREC_IS_REST_BUF_ENOUGH(stream, padding)) { /* char * cast to wchar * */ SECUREC_SAFE_PADDING(SECUREC_CHAR(' '), padding, stream, &charsOut); } else { SECUREC_WRITE_MULTI_CHAR(SECUREC_CHAR(' '), padding, stream, &charsOut); } } break; } #if SECUREC_ENABLE_SPRINTF_FLOAT if (floatBuf != NULL) { SECUREC_FREE(floatBuf); floatBuf = NULL; } #endif break; case STAT_INVALID: return -1; default: return -1; /* input format is wrong, directly return */ } } if (state != STAT_NORMAL && state != STAT_TYPE) { return -1; } return charsOut; /* the number of characters written */ } #endif /* OUTPUT_INL_2B263E9C_43D8_44BB_B17A_6D2033DECEE5 */ ================================================ FILE: third_party/securec/src/scanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securec.h" /* * * The scanf_s function is equivalent to fscanf_s with the argument stdin interposed before the arguments to scanf_s * The scanf_s function reads data from the standard input stream stdin and * writes the data into the location that's given by argument. Each argument * must be a pointer to a variable of a type that corresponds to a type specifier * in format. If copying occurs between strings that overlap, the behavior is * undefined. * * * format Format control string. * ... Optional arguments. * * * ... 
The converted value stored in user assigned address * * * Returns the number of fields successfully converted and assigned; * the return value does not include fields that were read but not assigned. * A return value of 0 indicates that no fields were assigned. * return -1 if an error occurs. */ int scanf_s(const char *format, ...) { int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); ret = vscanf_s(format, argList); va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } ================================================ FILE: third_party/securec/src/secinput.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef SEC_INPUT_H_E950DA2C_902F_4B15_BECD_948E99090D9C #define SEC_INPUT_H_E950DA2C_902F_4B15_BECD_948E99090D9C #include "securecutil.h" #define SECUREC_SCANF_EINVAL (-1) #define SECUREC_SCANF_ERROR_PARA (-2) /* for internal stream flag */ #define SECUREC_MEM_STR_FLAG 0X01 #define SECUREC_FILE_STREAM_FLAG 0X02 #define SECUREC_FROM_STDIN_FLAG 0X04 #define SECUREC_LOAD_FILE_TO_MEM_FLAG 0X08 #define SECUREC_UNINITIALIZED_FILE_POS (-1) #define SECUREC_BOM_HEADER_SIZE 2 #define SECUREC_BOM_HEADER_BE_1ST 0xFEU #define SECUREC_BOM_HEADER_BE_2ST 0xFFU #define SECUREC_BOM_HEADER_LE_1ST 0xFFU #define SECUREC_BOM_HEADER_LE_2ST 0xFEU #define SECUREC_UTF8_BOM_HEADER_SIZE 3 #define SECUREC_UTF8_BOM_HEADER_1ST 0xEFU #define SECUREC_UTF8_BOM_HEADER_2ND 0xBBU #define SECUREC_UTF8_BOM_HEADER_3RD 0xBFU #define SECUREC_UTF8_LEAD_1ST 0xE0 #define SECUREC_UTF8_LEAD_2ND 0x80 typedef struct { unsigned int flag; /* mark the properties of input stream */ int count; /* the size of buffered string in bytes */ const char *cur; /* the pointer to next read position */ char *base; /* the pointer to the header of buffered string */ #if SECUREC_ENABLE_SCANF_FILE FILE *pf; /* the file pointer */ long oriFilePos; /* the original position of file offset when fscanf is called */ int fileRealRead; #if defined(SECUREC_NO_STD_UNGETC) unsigned int lastChar; /* the char code of last input */ int fUnget; /* the boolean flag of pushing a char back to read stream */ #endif #endif } SecFileStream; #define SECUREC_INIT_SEC_FILE_STREAM_COMMON(fileStream, streamFlag, curPtr, strCount) do { \ (fileStream).flag = (streamFlag); \ (fileStream).count = (strCount); \ (fileStream).cur = (curPtr); \ (fileStream).base = NULL; \ } SECUREC_WHILE_ZERO #if SECUREC_ENABLE_SCANF_FILE #if defined(SECUREC_NO_STD_UNGETC) /* This initialization for eliminating redundant initialization. 
* Compared with the previous version initialization 0, * the current code causes the binary size to increase by some bytes */ #define SECUREC_INIT_SEC_FILE_STREAM(fileStream, streamFlag, stream, filePos, curPtr, strCount) do { \ SECUREC_INIT_SEC_FILE_STREAM_COMMON((fileStream), (streamFlag), (curPtr), (strCount)); \ (fileStream).pf = (stream); \ (fileStream).oriFilePos = (filePos); \ (fileStream).fileRealRead = 0; \ (fileStream).lastChar = 0; \ (fileStream).fUnget = 0; \ } SECUREC_WHILE_ZERO #else #define SECUREC_INIT_SEC_FILE_STREAM(fileStream, streamFlag, stream, filePos, curPtr, strCount) do { \ SECUREC_INIT_SEC_FILE_STREAM_COMMON((fileStream), (streamFlag), (curPtr), (strCount)); \ (fileStream).pf = (stream); \ (fileStream).oriFilePos = (filePos); \ (fileStream).fileRealRead = 0; \ } SECUREC_WHILE_ZERO #endif #else /* No SECUREC_ENABLE_SCANF_FILE */ #define SECUREC_INIT_SEC_FILE_STREAM(fileStream, streamFlag, stream, filePos, curPtr, strCount) do { \ SECUREC_INIT_SEC_FILE_STREAM_COMMON((fileStream), (streamFlag), (curPtr), (strCount)); \ } SECUREC_WHILE_ZERO #endif #ifdef __cplusplus extern "C" { #endif extern int SecInputS(SecFileStream *stream, const char *cFormat, va_list argList); extern void SecClearDestBuf(const char *buffer, const char *format, va_list argList); #if SECUREC_IN_KERNEL == 0 extern int SecInputSW(SecFileStream *stream, const wchar_t *cFormat, va_list argList); extern void SecClearDestBufW(const wchar_t *buffer, const wchar_t *format, va_list argList); #endif /* 20150105 For software and hardware decoupling,such as UMG */ #if defined(SECUREC_SYSAPI4VXWORKS) #ifdef feof #undef feof #endif extern int feof(FILE *stream); #endif #if defined(SECUREC_SYSAPI4VXWORKS) || defined(SECUREC_CTYPE_MACRO_ADAPT) #ifndef isspace #define isspace(c) (((c) == ' ') || ((c) == '\t') || ((c) == '\r') || ((c) == '\n')) #endif #ifndef iswspace #define iswspace(c) (((c) == L' ') || ((c) == L'\t') || ((c) == L'\r') || ((c) == L'\n')) #endif #ifndef isascii #define isascii(c) (((unsigned char)(c)) <= 0x7f) #endif #ifndef isupper #define isupper(c) ((c) >= 'A' && (c) <= 'Z') #endif #ifndef islower #define islower(c) ((c) >= 'a' && (c) <= 'z') #endif #ifndef isalpha #define isalpha(c) (isupper(c) || (islower(c))) #endif #ifndef isdigit #define isdigit(c) ((c) >= '0' && (c) <= '9') #endif #ifndef isxupper #define isxupper(c) ((c) >= 'A' && (c) <= 'F') #endif #ifndef isxlower #define isxlower(c) ((c) >= 'a' && (c) <= 'f') #endif #ifndef isxdigit #define isxdigit(c) (isdigit(c) || isxupper(c) || isxlower(c)) #endif #endif #ifdef __cplusplus } #endif /* Reserved file operation macro interface */ #define SECUREC_LOCK_FILE(s) #define SECUREC_UNLOCK_FILE(s) #define SECUREC_LOCK_STDIN(i, s) #define SECUREC_UNLOCK_STDIN(i, s) #endif ================================================ FILE: third_party/securec/src/securecutil.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* To avoid duplicate header files, include only securecutil.h */ #include "securecutil.h" #if defined(ANDROID) && (SECUREC_HAVE_WCTOMB || SECUREC_HAVE_MBTOWC) #include <wchar.h> #if SECUREC_HAVE_WCTOMB /* * Convert wide characters to narrow multi-bytes */ int wctomb(char *s, wchar_t wc) { return wcrtomb(s, wc, NULL); } #endif #if SECUREC_HAVE_MBTOWC /* * Converting narrow multi-byte characters to wide characters */ int mbtowc(wchar_t *pwc, const char *s, size_t n) { return mbrtowc(pwc, s, n, NULL); } #endif #endif /* high Num << 8 | num of SPC Ver */ #define SECUREC_C_VERSION (0x5 << 8) #define SECUREC_SPC_VERSION 7 #define SECUREC_VERSION_STR "Huawei Secure C V100R001C01SPC007B002" /* SPC verNumber<->verStr like: * 0X201<->C01 * 0X202<->SPC001 Redefine numbers after this version * 0X502<->SPC002 * 0X503<->SPC003 * ... * 0X50a<->SPC010 * 0X50b<->SPC011 * ... */ /* CP verNumber<->verStr like: * 0X601<->CP0001 * 0X602<->CP0002 * ... */ const char *GetHwSecureCVersion(unsigned short *verNumber) { if (verNumber != NULL) { *verNumber = (unsigned short)(SECUREC_C_VERSION | SECUREC_SPC_VERSION); } return SECUREC_VERSION_STR; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(GetHwSecureCVersion); #endif ================================================ FILE: third_party/securec/src/securecutil.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef SECURECUTIL_H_46C86578_F8FF_4E49_8E64_9B175241761F #define SECURECUTIL_H_46C86578_F8FF_4E49_8E64_9B175241761F #include "securec.h" #if (defined(_MSC_VER)) && (_MSC_VER >= 1400) #define SECUREC_MASK_MSVC_CRT_WARNING __pragma(warning(push)) \ __pragma(warning(disable:4996 4127)) #define SECUREC_END_MASK_MSVC_CRT_WARNING __pragma(warning(pop)) #else #define SECUREC_MASK_MSVC_CRT_WARNING #define SECUREC_END_MASK_MSVC_CRT_WARNING #endif #define SECUREC_WHILE_ZERO SECUREC_MASK_MSVC_CRT_WARNING while (0) SECUREC_END_MASK_MSVC_CRT_WARNING #ifndef SECUREC_HAVE_STRNLEN #if (defined(_XOPEN_SOURCE) && _XOPEN_SOURCE >= 700) || (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200809L) #if SECUREC_IN_KERNEL #define SECUREC_HAVE_STRNLEN 0 #else #if defined(__GLIBC__) && __GLIBC__ >= 2 && defined(__GLIBC_MINOR__) && __GLIBC_MINOR__ >= 10 #define SECUREC_HAVE_STRNLEN 1 #else #define SECUREC_HAVE_STRNLEN 0 #endif #endif #else #define SECUREC_HAVE_STRNLEN 0 #endif #endif #if SECUREC_IN_KERNEL /* in kernel, disable these functions */ #ifndef SECUREC_ENABLE_SCANF_FILE #define SECUREC_ENABLE_SCANF_FILE 0 #endif #ifndef SECUREC_ENABLE_SCANF_FLOAT #define SECUREC_ENABLE_SCANF_FLOAT 0 #endif #ifndef SECUREC_ENABLE_SPRINTF_FLOAT #define SECUREC_ENABLE_SPRINTF_FLOAT 0 #endif #ifndef SECUREC_HAVE_MBTOWC #define SECUREC_HAVE_MBTOWC 0 #endif #ifndef SECUREC_HAVE_WCTOMB #define SECUREC_HAVE_WCTOMB 0 #endif #ifndef SECUREC_HAVE_WCHART #define SECUREC_HAVE_WCHART 0 #endif #else /* not in kernel */ /* Systems that do not support FILE can define this macro to 0.
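 * For illustration only (an assumed build invocation, not from the original header): an embedded
 * target without FILE or floating-point support could predefine these feature macros when
 * compiling the library, e.g. cc -DSECUREC_ENABLE_SCANF_FILE=0 -DSECUREC_ENABLE_SCANF_FLOAT=0 -DSECUREC_ENABLE_SPRINTF_FLOAT=0 -c *.c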
*/ #ifndef SECUREC_ENABLE_SCANF_FILE #define SECUREC_ENABLE_SCANF_FILE 1 #endif #ifndef SECUREC_ENABLE_SCANF_FLOAT #define SECUREC_ENABLE_SCANF_FLOAT 1 #endif /* Systems that do not support float, can define this macro to 0. */ #ifndef SECUREC_ENABLE_SPRINTF_FLOAT #define SECUREC_ENABLE_SPRINTF_FLOAT 1 #endif #ifndef SECUREC_HAVE_MBTOWC #define SECUREC_HAVE_MBTOWC 1 #endif #ifndef SECUREC_HAVE_WCTOMB #define SECUREC_HAVE_WCTOMB 1 #endif #ifndef SECUREC_HAVE_WCHART #define SECUREC_HAVE_WCHART 1 #endif #endif #define SECUREC_INT_MAX 2147483647 #define SECUREC_MUL_SIXTEEN(x) ((x) << 4) #define SECUREC_MUL_EIGHT(x) ((x) << 3) #define SECUREC_MUL_TEN(x) ((((x) << 2) + (x)) << 1) /* Limited format input and output width */ #define SECUREC_MAX_WIDTH_LEN_DIV_TEN 21474836 #define SECUREC_MAX_WIDTH_LEN SECUREC_MUL_TEN(SECUREC_MAX_WIDTH_LEN_DIV_TEN) /* Is the x multiplied by 10 greater than */ #define SECUREC_MUL_TEN_ADD_BEYOND_MAX(x) (((x) > SECUREC_MAX_WIDTH_LEN_DIV_TEN)) #define SECUREC_FLOAT_BUFSIZE (309 + 40) /* Max length of double value */ #define SECUREC_FLOAT_BUFSIZE_LB (4932 + 40) /* Max length of long double value */ #define SECUREC_FLOAT_DEFAULT_PRECISION 6 /* This macro does not handle pointer equality or integer overflow */ #define SECUREC_MEMORY_NO_OVERLAP(dest, src, count) \ (((src) < (dest) && ((const char *)(src) + (count)) <= (char *)(dest)) || \ ((dest) < (src) && ((char *)(dest) + (count)) <= (const char *)(src))) #define SECUREC_MEMORY_IS_OVERLAP(dest, src, count) \ (((src) < (dest) && ((const char *)(src) + (count)) > (char *)(dest)) || \ ((dest) < (src) && ((char *)(dest) + (count)) > (const char *)(src))) /* * Check whether the strings overlap, len is the length of the string not include terminator * Length is related to data type char or wchar , do not force conversion of types */ #define SECUREC_STRING_NO_OVERLAP(dest, src, len) \ (((src) < (dest) && ((src) + (len)) < (dest)) || \ ((dest) < (src) && ((dest) + (len)) < (src))) /* * Check whether the strings overlap for strcpy wcscpy function, dest len and src Len are not include terminator * Length is related to data type char or wchar , do not force conversion of types */ #define SECUREC_STRING_IS_OVERLAP(dest, src, len) \ (((src) < (dest) && ((src) + (len)) >= (dest)) || \ ((dest) < (src) && ((dest) + (len)) >= (src))) /* * Check whether the strings overlap for strcat wcscat function, dest len and src Len are not include terminator * Length is related to data type char or wchar , do not force conversion of types */ #define SECUREC_CAT_STRING_IS_OVERLAP(dest, destLen, src, srcLen) \ (((dest) < (src) && ((dest) + (destLen) + (srcLen)) >= (src)) || \ ((src) < (dest) && ((src) + (srcLen)) >= (dest))) #if SECUREC_HAVE_STRNLEN #define SECUREC_CALC_STR_LEN(str, maxLen, outLen) do { \ *(outLen) = strnlen((str), (maxLen)); \ } SECUREC_WHILE_ZERO #define SECUREC_CALC_STR_LEN_OPT(str, maxLen, outLen) do { \ if ((maxLen) > 8) { \ /* Optimization or len less then 8 */ \ if (*((str) + 0) == '\0') { \ *(outLen) = 0; \ } else if (*((str) + 1) == '\0') { \ *(outLen) = 1; \ } else if (*((str) + 2) == '\0') { \ *(outLen) = 2; \ } else if (*((str) + 3) == '\0') { \ *(outLen) = 3; \ } else if (*((str) + 4) == '\0') { \ *(outLen) = 4; \ } else if (*((str) + 5) == '\0') { \ *(outLen) = 5; \ } else if (*((str) + 6) == '\0') { \ *(outLen) = 6; \ } else if (*((str) + 7) == '\0') { \ *(outLen) = 7; \ } else if (*((str) + 8) == '\0') { \ /* Optimization with a length of 8 */ \ *(outLen) = 8; \ } else { \ /* The offset is 8 because the performance of 
8 byte alignment is high */ \ *(outLen) = 8 + strnlen((str) + 8, (maxLen) - 8); \ } \ } else { \ SECUREC_CALC_STR_LEN((str), (maxLen), (outLen)); \ } \ } SECUREC_WHILE_ZERO #else #define SECUREC_CALC_STR_LEN(str, maxLen, outLen) do { \ const char *strEnd = (const char *)(str); \ size_t availableSize = (size_t)(maxLen); \ while (availableSize > 0 && *strEnd != '\0') { \ --availableSize; \ ++strEnd; \ } \ *(outLen) = (size_t)(strEnd - (str)); \ } SECUREC_WHILE_ZERO #define SECUREC_CALC_STR_LEN_OPT SECUREC_CALC_STR_LEN #endif #define SECUREC_CALC_WSTR_LEN(str, maxLen, outLen) do { \ const wchar_t *strEnd = (const wchar_t *)(str); \ *(outLen) = 0; \ while (*(outLen) < (maxLen) && *strEnd != L'\0') { \ *(outLen) = *(outLen) + 1; \ ++strEnd; \ } \ } SECUREC_WHILE_ZERO #ifdef SECUREC_FORMAT_OUTPUT_INPUT #if defined(SECUREC_COMPATIBLE_WIN_FORMAT) || defined(__ARMCC_VERSION) typedef __int64 SecInt64; typedef unsigned __int64 SecUnsignedInt64; #if defined(__ARMCC_VERSION) typedef unsigned int SecUnsignedInt32; #else typedef unsigned __int32 SecUnsignedInt32; #endif #else typedef unsigned int SecUnsignedInt32; typedef long long SecInt64; typedef unsigned long long SecUnsignedInt64; #endif #ifdef SECUREC_FOR_WCHAR #if defined(SECUREC_VXWORKS_PLATFORM) && !defined(__WINT_TYPE__) typedef wchar_t wint_t; #endif typedef wchar_t SecChar; typedef wchar_t SecUnsignedChar; typedef wint_t SecInt; typedef wint_t SecUnsignedInt; #else /* no SECUREC_FOR_WCHAR */ typedef char SecChar; typedef unsigned char SecUnsignedChar; typedef int SecInt; typedef unsigned int SecUnsignedInt; #endif #endif /* Determine whether the address is 8-byte aligned * Some systems do not have uintptr_t type, so use NULL to clear tool alarm 507 */ #define SECUREC_ADDR_ALIGNED_8(addr) (SecIsAddrAligned8((addr), NULL) == 0) /* If you define the memory allocation function, * you need to define the function prototype. You can define this macro as a header file. 
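 * Illustrative override sketch (hypothetical names, not part of this header): a project with its
 * own pooled allocator could predefine, before this header is included,
 * #define SECUREC_MALLOC_PROTOTYPE extern void *MyPoolAlloc(size_t size); extern void MyPoolFree(void *ptr);
 * together with #define SECUREC_MALLOC(x) MyPoolAlloc((size_t)(x)) and
 * #define SECUREC_FREE(x) MyPoolFree((void *)(x)), so all library allocations go through the pool.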
*/ #if defined(SECUREC_MALLOC_PROTOTYPE) SECUREC_MALLOC_PROTOTYPE #endif #ifndef SECUREC_MALLOC #define SECUREC_MALLOC(x) malloc((size_t)(x)) #endif #ifndef SECUREC_FREE #define SECUREC_FREE(x) free((void *)(x)) #endif /* struct for performance */ typedef struct { unsigned char buf[1]; /* Performance optimization code structure assignment length 1 bytes */ } SecStrBuf1; typedef struct { unsigned char buf[2]; /* Performance optimization code structure assignment length 2 bytes */ } SecStrBuf2; typedef struct { unsigned char buf[3]; /* Performance optimization code structure assignment length 3 bytes */ } SecStrBuf3; typedef struct { unsigned char buf[4]; /* Performance optimization code structure assignment length 4 bytes */ } SecStrBuf4; typedef struct { unsigned char buf[5]; /* Performance optimization code structure assignment length 5 bytes */ } SecStrBuf5; typedef struct { unsigned char buf[6]; /* Performance optimization code structure assignment length 6 bytes */ } SecStrBuf6; typedef struct { unsigned char buf[7]; /* Performance optimization code structure assignment length 7 bytes */ } SecStrBuf7; typedef struct { unsigned char buf[8]; /* Performance optimization code structure assignment length 8 bytes */ } SecStrBuf8; typedef struct { unsigned char buf[9]; /* Performance optimization code structure assignment length 9 bytes */ } SecStrBuf9; typedef struct { unsigned char buf[10]; /* Performance optimization code structure assignment length 10 bytes */ } SecStrBuf10; typedef struct { unsigned char buf[11]; /* Performance optimization code structure assignment length 11 bytes */ } SecStrBuf11; typedef struct { unsigned char buf[12]; /* Performance optimization code structure assignment length 12 bytes */ } SecStrBuf12; typedef struct { unsigned char buf[13]; /* Performance optimization code structure assignment length 13 bytes */ } SecStrBuf13; typedef struct { unsigned char buf[14]; /* Performance optimization code structure assignment length 14 bytes */ } SecStrBuf14; typedef struct { unsigned char buf[15]; /* Performance optimization code structure assignment length 15 bytes */ } SecStrBuf15; typedef struct { unsigned char buf[16]; /* Performance optimization code structure assignment length 16 bytes */ } SecStrBuf16; typedef struct { unsigned char buf[17]; /* Performance optimization code structure assignment length 17 bytes */ } SecStrBuf17; typedef struct { unsigned char buf[18]; /* Performance optimization code structure assignment length 18 bytes */ } SecStrBuf18; typedef struct { unsigned char buf[19]; /* Performance optimization code structure assignment length 19 bytes */ } SecStrBuf19; typedef struct { unsigned char buf[20]; /* Performance optimization code structure assignment length 20 bytes */ } SecStrBuf20; typedef struct { unsigned char buf[21]; /* Performance optimization code structure assignment length 21 bytes */ } SecStrBuf21; typedef struct { unsigned char buf[22]; /* Performance optimization code structure assignment length 22 bytes */ } SecStrBuf22; typedef struct { unsigned char buf[23]; /* Performance optimization code structure assignment length 23 bytes */ } SecStrBuf23; typedef struct { unsigned char buf[24]; /* Performance optimization code structure assignment length 24 bytes */ } SecStrBuf24; typedef struct { unsigned char buf[25]; /* Performance optimization code structure assignment length 25 bytes */ } SecStrBuf25; typedef struct { unsigned char buf[26]; /* Performance optimization code structure assignment length 26 bytes */ } SecStrBuf26; 
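/* Illustrative aside (added; not in the original header): the SecStrBufN family exists so a fixed-length copy can be written as a single struct assignment, which compilers lower to one fixed-size move instead of a memcpy call, e.g. *(SecStrBuf4 *)(void *)dest = *(const SecStrBuf4 *)(const void *)src; where dest and src are illustrative pointers. */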
typedef struct { unsigned char buf[27]; /* Performance optimization code structure assignment length 27 bytes */ } SecStrBuf27; typedef struct { unsigned char buf[28]; /* Performance optimization code structure assignment length 28 bytes */ } SecStrBuf28; typedef struct { unsigned char buf[29]; /* Performance optimization code structure assignment length 29 bytes */ } SecStrBuf29; typedef struct { unsigned char buf[30]; /* Performance optimization code structure assignment length 30 bytes */ } SecStrBuf30; typedef struct { unsigned char buf[31]; /* Performance optimization code structure assignment length 31 bytes */ } SecStrBuf31; typedef struct { unsigned char buf[32]; /* Performance optimization code structure assignment length 32 bytes */ } SecStrBuf32; typedef struct { unsigned char buf[33]; /* Performance optimization code structure assignment length 33 bytes */ } SecStrBuf33; typedef struct { unsigned char buf[34]; /* Performance optimization code structure assignment length 34 bytes */ } SecStrBuf34; typedef struct { unsigned char buf[35]; /* Performance optimization code structure assignment length 35 bytes */ } SecStrBuf35; typedef struct { unsigned char buf[36]; /* Performance optimization code structure assignment length 36 bytes */ } SecStrBuf36; typedef struct { unsigned char buf[37]; /* Performance optimization code structure assignment length 37 bytes */ } SecStrBuf37; typedef struct { unsigned char buf[38]; /* Performance optimization code structure assignment length 38 bytes */ } SecStrBuf38; typedef struct { unsigned char buf[39]; /* Performance optimization code structure assignment length 39 bytes */ } SecStrBuf39; typedef struct { unsigned char buf[40]; /* Performance optimization code structure assignment length 40 bytes */ } SecStrBuf40; typedef struct { unsigned char buf[41]; /* Performance optimization code structure assignment length 41 bytes */ } SecStrBuf41; typedef struct { unsigned char buf[42]; /* Performance optimization code structure assignment length 42 bytes */ } SecStrBuf42; typedef struct { unsigned char buf[43]; /* Performance optimization code structure assignment length 43 bytes */ } SecStrBuf43; typedef struct { unsigned char buf[44]; /* Performance optimization code structure assignment length 44 bytes */ } SecStrBuf44; typedef struct { unsigned char buf[45]; /* Performance optimization code structure assignment length 45 bytes */ } SecStrBuf45; typedef struct { unsigned char buf[46]; /* Performance optimization code structure assignment length 46 bytes */ } SecStrBuf46; typedef struct { unsigned char buf[47]; /* Performance optimization code structure assignment length 47 bytes */ } SecStrBuf47; typedef struct { unsigned char buf[48]; /* Performance optimization code structure assignment length 48 bytes */ } SecStrBuf48; typedef struct { unsigned char buf[49]; /* Performance optimization code structure assignment length 49 bytes */ } SecStrBuf49; typedef struct { unsigned char buf[50]; /* Performance optimization code structure assignment length 50 bytes */ } SecStrBuf50; typedef struct { unsigned char buf[51]; /* Performance optimization code structure assignment length 51 bytes */ } SecStrBuf51; typedef struct { unsigned char buf[52]; /* Performance optimization code structure assignment length 52 bytes */ } SecStrBuf52; typedef struct { unsigned char buf[53]; /* Performance optimization code structure assignment length 53 bytes */ } SecStrBuf53; typedef struct { unsigned char buf[54]; /* Performance optimization code structure assignment 
length 54 bytes */ } SecStrBuf54; typedef struct { unsigned char buf[55]; /* Performance optimization code structure assignment length 55 bytes */ } SecStrBuf55; typedef struct { unsigned char buf[56]; /* Performance optimization code structure assignment length 56 bytes */ } SecStrBuf56; typedef struct { unsigned char buf[57]; /* Performance optimization code structure assignment length 57 bytes */ } SecStrBuf57; typedef struct { unsigned char buf[58]; /* Performance optimization code structure assignment length 58 bytes */ } SecStrBuf58; typedef struct { unsigned char buf[59]; /* Performance optimization code structure assignment length 59 bytes */ } SecStrBuf59; typedef struct { unsigned char buf[60]; /* Performance optimization code structure assignment length 60 bytes */ } SecStrBuf60; typedef struct { unsigned char buf[61]; /* Performance optimization code structure assignment length 61 bytes */ } SecStrBuf61; typedef struct { unsigned char buf[62]; /* Performance optimization code structure assignment length 62 bytes */ } SecStrBuf62; typedef struct { unsigned char buf[63]; /* Performance optimization code structure assignment length 63 bytes */ } SecStrBuf63; typedef struct { unsigned char buf[64]; /* Performance optimization code structure assignment length 64 bytes */ } SecStrBuf64; /* User can change the error handler by modify the following definition, * such as logging the detail error in file. */ #if defined(_DEBUG) || defined(DEBUG) #if defined(SECUREC_ERROR_HANDLER_BY_ASSERT) #define SECUREC_ERROR_INVALID_PARAMTER(msg) assert(msg "invalid argument" == NULL) #define SECUREC_ERROR_INVALID_RANGE(msg) assert(msg "invalid dest buffer size" == NULL) #define SECUREC_ERROR_BUFFER_OVERLAP(msg) assert(msg "buffer overlap" == NULL) #elif defined(SECUREC_ERROR_HANDLER_BY_PRINTF) #if SECUREC_IN_KERNEL #define SECUREC_ERROR_INVALID_PARAMTER(msg) printk("%s invalid argument\n", msg) #define SECUREC_ERROR_INVALID_RANGE(msg) printk("%s invalid dest buffer size\n", msg) #define SECUREC_ERROR_BUFFER_OVERLAP(msg) printk("%s buffer overlap\n", msg) #else #define SECUREC_ERROR_INVALID_PARAMTER(msg) printf("%s invalid argument\n", msg) #define SECUREC_ERROR_INVALID_RANGE(msg) printf("%s invalid dest buffer size\n", msg) #define SECUREC_ERROR_BUFFER_OVERLAP(msg) printf("%s buffer overlap\n", msg) #endif #elif defined(SECUREC_ERROR_HANDLER_BY_FILE_LOG) #define SECUREC_ERROR_INVALID_PARAMTER(msg) LogSecureCRuntimeError(msg " EINVAL\n") #define SECUREC_ERROR_INVALID_RANGE(msg) LogSecureCRuntimeError(msg " ERANGE\n") #define SECUREC_ERROR_BUFFER_OVERLAP(msg) LogSecureCRuntimeError(msg " EOVERLAP\n") #else /* no HANDLER is defined */ #define SECUREC_ERROR_INVALID_PARAMTER(msg) ((void)0) #define SECUREC_ERROR_INVALID_RANGE(msg) ((void)0) #define SECUREC_ERROR_BUFFER_OVERLAP(msg) ((void)0) #endif #else /* no DEBUG */ #define SECUREC_ERROR_INVALID_PARAMTER(msg) ((void)0) #define SECUREC_ERROR_INVALID_RANGE(msg) ((void)0) #define SECUREC_ERROR_BUFFER_OVERLAP(msg) ((void)0) #endif #ifdef __cplusplus extern "C" { #endif /* assembly language memory copy and memory set for X86 or MIPS ... 
*/ #ifdef SECUREC_USE_ASM extern void *memcpy_opt(void *, const void *, size_t); extern void *memset_opt(void *, int, size_t); #endif #if defined(SECUREC_ERROR_HANDLER_BY_FILE_LOG) extern void LogSecureCRuntimeError(const char *errDetail); #endif #ifdef SECUREC_INLINE_DO_MEMCPY static void SecDoMemcpy(void *dest, const void *src, size_t count) { /* * if SECUREC_USE_ASM macro is enabled, it will call assembly language function to improve performance. */ #ifdef SECUREC_USE_ASM (void)memcpy_opt(dest, src, count); #else /* large enough, let system API do it */ (void)memcpy(dest, src, count); #endif } #endif #ifdef SECUREC_INLINE_DO_MEMSET static void SecDoMemset(void *dest, int c, size_t count) { #ifdef SECUREC_USE_ASM (void)memset_opt(dest, c, count); #else (void)memset(dest, c, count); #endif } #endif #ifdef SECUREC_INLINE_STR_LEN /* The function compiler will be inlined and not placed in other files */ static size_t SecStrMinLen(const char *str, size_t maxLen) { size_t len; SECUREC_CALC_STR_LEN(str, maxLen, &len); return len; } #endif #ifdef SECUREC_INLINE_STR_LEN_OPT /* The function compiler will be inlined and not placed in other files */ static size_t SecStrMinLenOpt(const char *str, size_t maxLen) { size_t len; SECUREC_CALC_STR_LEN_OPT(str, maxLen, &len); return len; } #endif #ifdef __cplusplus } #endif /* __cplusplus */ #endif ================================================ FILE: third_party/securec/src/secureinput_a.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define SECUREC_FORMAT_OUTPUT_INPUT 1 #ifdef SECUREC_FOR_WCHAR #undef SECUREC_FOR_WCHAR #endif #include "secinput.h" #include "input.inl" ================================================ FILE: third_party/securec/src/secureinput_w.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* if some platforms don't have wchar.h, don't include it */ #if !(defined(SECUREC_VXWORKS_PLATFORM)) /* This header file is placed below secinput.h, which will cause tool alarm, * but if there is no macro above, it will cause vs2010 compiling alarm */ #if defined(_MSC_VER) && (_MSC_VER >= 1400) #ifndef __STDC_WANT_SECURE_LIB__ /* The order of adjustment is to eliminate alarm of Duplicate Block */ #define __STDC_WANT_SECURE_LIB__ 0 #endif #ifndef _CRTIMP_ALTERNATIVE #define _CRTIMP_ALTERNATIVE /* comment microsoft *_s function */ #endif #endif #include <wchar.h> #endif #define SECUREC_ENABLE_WCHAR_FUNC 0 #define SECUREC_FORMAT_OUTPUT_INPUT 1 #ifndef SECUREC_FOR_WCHAR #define SECUREC_FOR_WCHAR #endif #include "secinput.h" #ifndef WEOF #define WEOF ((wchar_t)(-1)) #endif #include "input.inl" ================================================ FILE: third_party/securec/src/secureprintoutput.h ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef SECUREPRINTOUTPUT_H_E950DA2C_902F_4B15_BECD_948E99090D9C #define SECUREPRINTOUTPUT_H_E950DA2C_902F_4B15_BECD_948E99090D9C #include "securecutil.h" /* flag definitions */ /* Macros are used instead of enumerations because some compilers implement enumerated types as 16-bit. */ #define SECUREC_FLAG_SIGN 0x00001U #define SECUREC_FLAG_SIGN_SPACE 0x00002U #define SECUREC_FLAG_LEFT 0x00004U #define SECUREC_FLAG_LEADZERO 0x00008U #define SECUREC_FLAG_LONG 0x00010U #define SECUREC_FLAG_SHORT 0x00020U #define SECUREC_FLAG_SIGNED 0x00040U #define SECUREC_FLAG_ALTERNATE 0x00080U #define SECUREC_FLAG_NEGATIVE 0x00100U #define SECUREC_FLAG_FORCE_OCTAL 0x00200U #define SECUREC_FLAG_LONG_DOUBLE 0x00400U #define SECUREC_FLAG_WIDECHAR 0x00800U #define SECUREC_FLAG_LONGLONG 0x01000U #define SECUREC_FLAG_CHAR 0x02000U #define SECUREC_FLAG_POINTER 0x04000U #define SECUREC_FLAG_I64 0x08000U #define SECUREC_FLAG_PTRDIFF 0x10000U #define SECUREC_FLAG_SIZE 0x20000U #ifdef SECUREC_COMPATIBLE_LINUX_FORMAT #define SECUREC_FLAG_INTMAX 0x40000U #endif /* state definitions. Identify the status of the current format */ typedef enum { STAT_NORMAL, STAT_PERCENT, STAT_FLAG, STAT_WIDTH, STAT_DOT, STAT_PRECIS, STAT_SIZE, STAT_TYPE, STAT_INVALID } SecFmtState; /* Format output buffer pointer and available size */ typedef struct { int count; char *cur; } SecPrintfStream; #ifndef SECUREC_BUFFER_SIZE #ifdef SECUREC_STACK_SIZE_LESS_THAN_1K /* SECUREC_BUFFER_SIZE cannot be less than 23, * the length of the octal representation of a 64-bit integer with a leading zero */ #define SECUREC_BUFFER_SIZE 256 #else #define SECUREC_BUFFER_SIZE 512 #endif #endif #if SECUREC_BUFFER_SIZE < 23 #error SECUREC_BUFFER_SIZE Can not be less than 23 #endif #define SECUREC_MAX_PRECISION SECUREC_BUFFER_SIZE /* max.
# bytes in multibyte char ,see MB_LEN_MAX */ #define SECUREC_MB_LEN 16 /* The return value of the internal function, which is returned when truncated */ #define SECUREC_PRINTF_TRUNCATE (-2) #ifdef __cplusplus extern "C" { #endif extern int SecVsnprintfImpl(char *string, size_t count, const char *format, va_list argList); #if SECUREC_IN_KERNEL == 0 extern int SecVswprintfImpl(wchar_t *string, size_t sizeInWchar, const wchar_t *format, va_list argList); #endif #ifdef __cplusplus } #endif #endif ================================================ FILE: third_party/securec/src/secureprintoutput_a.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define SECUREC_INLINE_DO_MEMCPY 1 #define SECUREC_FORMAT_OUTPUT_INPUT 1 #ifdef SECUREC_FOR_WCHAR #undef SECUREC_FOR_WCHAR #endif #include "secureprintoutput.h" #define SECUREC_CHAR(x) x #define SECUREC_WRITE_MULTI_CHAR SecWriteMultiChar #define SECUREC_WRITE_STRING SecWriteString #ifndef EOF #define EOF (-1) #endif /* put a char to output */ #define SECUREC_PUTC(c, outStream) ((--(outStream)->count >= 0) ? \ (int)((unsigned int)(unsigned char)(*((outStream)->cur++) = (char)(c)) & 0xff) : EOF) /* to clear e835 */ #define SECUREC_PUTC_ZERO(outStream) ((--(outStream)->count >= 0) ? 
\ ((*((outStream)->cur++) = (char)('\0'))) : EOF) static void SecWriteMultiChar(char ch, int num, SecPrintfStream *f, int *pnumwritten); static void SecWriteString(const char *string, int len, SecPrintfStream *f, int *pnumwritten); #include "output.inl" /* * Narrow character formatted output implementation */ int SecVsnprintfImpl(char *string, size_t count, const char *format, va_list argList) { SecPrintfStream str; int retVal; str.count = (int)count; /* this count includes the '\0' character and must be greater than zero */ str.cur = string; retVal = SecOutputS(&str, format, argList); if ((retVal >= 0) && (SECUREC_PUTC_ZERO(&str) != EOF)) { return retVal; } else if (str.count < 0) { /* the buffer was too small; we return truncation */ string[count - 1] = '\0'; return SECUREC_PRINTF_TRUNCATE; } string[0] = '\0'; /* empty the dest string */ return -1; } /* * Sec write one character multiple times */ static void SecWriteMultiChar(char ch, int num, SecPrintfStream *f, int *pnumwritten) { int count = num; while (count-- > 0) { if (SECUREC_PUTC(ch, f) == EOF) { *pnumwritten = -1; break; } else { *pnumwritten = *pnumwritten + 1; } } } /* * Sec write string function */ static void SecWriteString(const char *string, int len, SecPrintfStream *f, int *pnumwritten) { const char *str = string; int count = len; while (count-- > 0) { if (SECUREC_PUTC(*str, f) == EOF) { *pnumwritten = -1; break; } else { *pnumwritten = *pnumwritten + 1; ++str; } } } ================================================ FILE: third_party/securec/src/secureprintoutput_w.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ /* if some platforms don't have wchar.h, don't include it */ #if !(defined(SECUREC_VXWORKS_PLATFORM)) /* This header file is placed below secinput.h, which will cause tool alarm, * but if there is no macro above, it will cause compiling alarm */ #if defined(_MSC_VER) && (_MSC_VER >= 1400) #ifndef _CRTIMP_ALTERNATIVE #define _CRTIMP_ALTERNATIVE /* comment microsoft *_s function */ #endif #ifndef __STDC_WANT_SECURE_LIB__ #define __STDC_WANT_SECURE_LIB__ 0 #endif #endif #include <wchar.h> #endif #define SECUREC_ENABLE_WCHAR_FUNC 0 #define SECUREC_INLINE_DO_MEMCPY 1 #define SECUREC_FORMAT_OUTPUT_INPUT 1 #ifndef SECUREC_FOR_WCHAR #define SECUREC_FOR_WCHAR #endif #include "secureprintoutput.h" #ifndef WEOF #define WEOF ((wchar_t)(-1)) #endif #define SECUREC_CHAR(x) L ## x #define SECUREC_WRITE_MULTI_CHAR SecWriteMultiCharW #define SECUREC_WRITE_STRING SecWriteStringW static void SecWriteCharW(wchar_t ch, SecPrintfStream *f, int *pnumwritten); static void SecWriteMultiCharW(wchar_t ch, int num, SecPrintfStream *f, int *pnumwritten); static void SecWriteStringW(const wchar_t *string, int len, SecPrintfStream *f, int *pnumwritten); static int SecPutWcharStrEndingZero(SecPrintfStream *str, int zeroCount); #include "output.inl" /* * Wide character formatted output implementation */ int SecVswprintfImpl(wchar_t *string, size_t sizeInWchar, const wchar_t *format, va_list argList) { SecPrintfStream str; int retVal; /* If initialization causes e838 */ str.cur = (char *)string; /* this count includes the L'\0' character and must be greater than zero */ str.count = (int)(sizeInWchar * sizeof(wchar_t)); retVal = SecOutputSW(&str, format, argList); if ((retVal >= 0) && SecPutWcharStrEndingZero(&str, (int)sizeof(wchar_t))) { return (retVal); } else if (str.count < 0) { /* the buffer was too small; we return truncation */ string[sizeInWchar - 1] = L'\0'; return SECUREC_PRINTF_TRUNCATE; } string[0] = L'\0'; return -1; } /* * Output one zero character into the SecPrintfStream structure */ static int SecPutZeroChar(SecPrintfStream *str) { if (str->count > 0) { *(str->cur) = (char)('\0'); str->count = str->count - 1; str->cur = str->cur + 1; return 0; } return -1; } /* * Output a wide-character terminating zero into the SecPrintfStream structure */ static int SecPutWcharStrEndingZero(SecPrintfStream *str, int zeroCount) { int succeed = 0; int i = 0; while (i < zeroCount && (SecPutZeroChar(str) == 0)) { ++i; } if (i == zeroCount) { succeed = 1; } return succeed; } /* * Output a wide character into the SecPrintfStream structure */ static wchar_t SecPutCharW(wchar_t ch, SecPrintfStream *f) { wchar_t wcRet = 0; if (((f)->count -= (int)sizeof(wchar_t)) >= 0) { *(wchar_t *)(void *)(f->cur) = ch; f->cur += sizeof(wchar_t); wcRet = ch; } else { wcRet = (wchar_t)WEOF; } return wcRet; } /* * Output a wide character into the SecPrintfStream structure, updating the number of characters written */ static void SecWriteCharW(wchar_t ch, SecPrintfStream *f, int *pnumwritten) { if (SecPutCharW(ch, f) == (wchar_t)WEOF) { *pnumwritten = -1; } else { *pnumwritten = *pnumwritten + 1; } } /* * Output multiple wide characters into the SecPrintfStream structure, updating the number of characters written */ static void SecWriteMultiCharW(wchar_t ch, int num, SecPrintfStream *f, int *pnumwritten) { int count = num; while (count-- > 0) { SecWriteCharW(ch, f, pnumwritten); if (*pnumwritten == -1) { break; } } } /* * Output a wide string into the SecPrintfStream structure, updating the number of characters written */ static void SecWriteStringW(const wchar_t
*string, int len, SecPrintfStream *f, int *pnumwritten) { const wchar_t *str = string; int count = len; while (count-- > 0) { SecWriteCharW(*str++, f, pnumwritten); if (*pnumwritten == -1) { break; } } } ================================================ FILE: third_party/securec/src/snprintf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securec.h" #if SECUREC_ENABLE_SNPRINTF /* * * The snprintf_s function is equivalent to the snprintf function * except for the parameter destMax/count and the explicit runtime-constraints violation * The snprintf_s function formats and stores count or fewer characters in * strDest and appends a terminating null. Each argument (if any) is converted * and output according to the corresponding format specification in format. * The formatting is consistent with the printf family of functions; If copying * occurs between strings that overlap, the behavior is undefined. * * * strDest Storage location for the output. * destMax The size of the storage location for output. Size * in bytes for snprintf_s or size in words for snwprintf_s. * count Maximum number of character to store. * format Format-control string. * ... Optional arguments. * * * strDest is updated * * * return the number of characters written, not including the terminating null * return -1 if an error occurs. * return -1 if count < destMax and the output string has been truncated * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid * */ int snprintf_s(char *strDest, size_t destMax, size_t count, const char *format, ...) { int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); ret = vsnprintf_s(strDest, destMax, count, format, argList); va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(snprintf_s); #endif #endif #if SECUREC_SNPRINTF_TRUNCATED /* * * The snprintf_truncated_s function is equivalent to the snprintf function * except for the parameter destMax/count and the explicit runtime-constraints violation * The snprintf_truncated_s function formats and stores count or fewer characters in * strDest and appends a terminating null. Each argument (if any) is converted * and output according to the corresponding format specification in format. * The formatting is consistent with the printf family of functions; If copying * occurs between strings that overlap, the behavior is undefined. * * * strDest Storage location for the output. * destMax The size of the storage location for output. Size * in bytes for snprintf_truncated_s or size in words for snwprintf_s. * format Format-control string. * ... Optional arguments. * * * strDest is updated * * * return the number of characters written, not including the terminating null * return -1 if an error occurs. 
* return destMax-1 if output string has been truncated * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid * */ int snprintf_truncated_s(char *strDest, size_t destMax, const char *format, ...) { int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); ret = vsnprintf_truncated_s(strDest, destMax, format, argList); va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(snprintf_truncated_s); #endif #endif ================================================ FILE: third_party/securec/src/sprintf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securec.h" /* * * The sprintf_s function is equivalent to the sprintf function * except for the parameter destMax and the explicit runtime-constraints violation * The sprintf_s function formats and stores a series of characters and values * in strDest. Each argument (if any) is converted and output according to * the corresponding format specification in format. The format consists of * ordinary characters and has the same form and function as the format argument * for printf. A null character is appended after the last character written. * If copying occurs between strings that overlap, the behavior is undefined. * * * strDest Storage location for output. * destMax Maximum number of characters to store. * format Format-control string. * ... Optional arguments * * * strDest is updated * * * return the number of bytes stored in strDest, not counting the terminating null character. * return -1 if an error occurred. * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ int sprintf_s(char *strDest, size_t destMax, const char *format, ...) { int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); ret = vsprintf_s(strDest, destMax, format, argList); va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(sprintf_s); #endif ================================================ FILE: third_party/securec/src/sscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
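 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch for the snprintf_s and snprintf_truncated_s wrappers from the
 *  preceding files, assuming a caller-owned buffer:
 *      char buf[16];
 *      int n = snprintf_s(buf, sizeof(buf), sizeof(buf) - 1, "%s-%d", "id", 42);
 *      // n == 5 and buf holds "id-42"; n == -1 on error or truncation
 *      n = snprintf_truncated_s(buf, sizeof(buf), "%s", "a very long input string");
 *      // on truncation this variant keeps the truncated text and returns destMax - 1
 *  ]
 *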
*/ #include "securec.h" /* * * The sscanf_s function is equivalent to fscanf_s, * except that input is obtained from a string (specified by the argument buffer) rather than from a stream * The sscanf function reads data from buffer into the location given by each * argument. Every argument must be a pointer to a variable with a type that * corresponds to a type specifier in format. The format argument controls the * interpretation of the input fields and has the same form and function as * the format argument for the scanf function. * If copying takes place between strings that overlap, the behavior is undefined. * * * buffer Stored data. * format Format control string, see Format Specifications. * ... Optional arguments. * * * ... The converted value stored in user assigned address * * * Each of these functions returns the number of fields successfully converted * and assigned; the return value does not include fields that were read but * not assigned. * A return value of 0 indicates that no fields were assigned. * return -1 if an error occurs. */ int sscanf_s(const char *buffer, const char *format, ...) { int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); ret = vsscanf_s(buffer, format, argList); va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(sscanf_s); #endif ================================================ FILE: third_party/securec/src/strcat_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define SECUREC_INLINE_STR_LEN 1 #define SECUREC_INLINE_STR_LEN_OPT 1 #define SECUREC_INLINE_DO_MEMCPY 1 #include "securecutil.h" /* * Befor this function, the basic parameter checking has been done */ static errno_t SecDoStrcat(char *strDest, size_t destMax, const char *strSrc) { size_t destLen = SecStrMinLen(strDest, destMax); /* Only optimize strSrc, do not apply this function to strDest */ size_t srcLen = SecStrMinLenOpt(strSrc, destMax - destLen); if (SECUREC_CAT_STRING_IS_OVERLAP(strDest, destLen, strSrc, srcLen)) { strDest[0] = '\0'; if (strDest + destLen <= strSrc && destLen == destMax) { SECUREC_ERROR_INVALID_PARAMTER("strcat_s"); return EINVAL_AND_RESET; } SECUREC_ERROR_BUFFER_OVERLAP("strcat_s"); return EOVERLAP_AND_RESET; } if (srcLen + destLen >= destMax || strDest == strSrc) { strDest[0] = '\0'; if (destLen == destMax) { SECUREC_ERROR_INVALID_PARAMTER("strcat_s"); return EINVAL_AND_RESET; } SECUREC_ERROR_INVALID_RANGE("strcat_s"); return ERANGE_AND_RESET; } SecDoMemcpy(strDest + destLen, strSrc, srcLen + 1); /* single character length include \0 */ return EOK; } /* * * The strcat_s function appends a copy of the string pointed to by strSrc (including the terminating null character) * to the end of the string pointed to by strDest. * The initial character of strSrc overwrites the terminating null character of strDest. 
* strcat_s will return EOVERLAP_AND_RESET if the source and destination strings overlap. * * Note that the second parameter is the total size of the buffer, not the * remaining size. * * * strDest Null-terminated destination string buffer. * destMax Size of the destination string buffer. * strSrc Null-terminated source string buffer. * * * strDest is updated * * * EOK Success * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN * EINVAL_AND_RESET (strDest unterminated and all other parameters are valid) or * (strDest != NULL and strSrc is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN) * ERANGE destMax is 0 or destMax > SECUREC_STRING_MAX_LEN * ERANGE_AND_RESET strDest does not have enough space and all other parameters are valid and not overlap * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ errno_t strcat_s(char *strDest, size_t destMax, const char *strSrc) { if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { SECUREC_ERROR_INVALID_RANGE("strcat_s"); return ERANGE; } if (strDest == NULL || strSrc == NULL) { SECUREC_ERROR_INVALID_PARAMTER("strcat_s"); if (strDest != NULL) { strDest[0] = '\0'; return EINVAL_AND_RESET; } return EINVAL; } return SecDoStrcat(strDest, destMax, strSrc); } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(strcat_s); #endif ================================================ FILE: third_party/securec/src/strcpy_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ #define SECUREC_INLINE_STR_LEN 1 #define SECUREC_INLINE_DO_MEMCPY 1 #include "securecutil.h" #if SECUREC_IN_KERNEL== 0 #ifndef SECUREC_STRCOPY_THRESHOLD_SIZE #define SECUREC_STRCOPY_THRESHOLD_SIZE 32UL #endif /* * Determine whether the address is 8-byte aligned, use static to increase performance * return 0 is aligned */ static int SecIsAddrAligned8(const void *addr, const void *zeroAddr) { return (int)(((size_t)((const char*)addr - (const char*)zeroAddr)) & 7); /* use 7 to check aligned 8 */ } /* The purpose of converting to void is to clean up the alarm */ #define SECUREC_SMALL_STR_COPY do { \ if (SECUREC_ADDR_ALIGNED_8(strDest) && SECUREC_ADDR_ALIGNED_8(strSrc)) { \ /* use struct assignment */ \ switch (srcStrLen) { \ case 1: \ *(SecStrBuf1 *)(void *)strDest = *(const SecStrBuf1 *)(const void *)strSrc; \ break; \ case 2: \ *(SecStrBuf2 *)(void *)strDest = *(const SecStrBuf2 *)(const void *)strSrc; \ break; \ case 3: \ *(SecStrBuf3 *)(void *)strDest = *(const SecStrBuf3 *)(const void *)strSrc; \ break; \ case 4: \ *(SecStrBuf4 *)(void *)strDest = *(const SecStrBuf4 *)(const void *)strSrc; \ break; \ case 5: \ *(SecStrBuf5 *)(void *)strDest = *(const SecStrBuf5 *)(const void *)strSrc; \ break; \ case 6: \ *(SecStrBuf6 *)(void *)strDest = *(const SecStrBuf6 *)(const void *)strSrc; \ break; \ case 7: \ *(SecStrBuf7 *)(void *)strDest = *(const SecStrBuf7 *)(const void *)strSrc; \ break; \ case 8: \ *(SecStrBuf8 *)(void *)strDest = *(const SecStrBuf8 *)(const void *)strSrc; \ break; \ case 9: \ *(SecStrBuf9 *)(void *)strDest = *(const SecStrBuf9 *)(const void *)strSrc; \ break; \ case 10: \ *(SecStrBuf10 *)(void *)strDest = *(const SecStrBuf10 *)(const void *)strSrc; \ break; \ case 11: \ *(SecStrBuf11 *)(void *)strDest = *(const SecStrBuf11 *)(const void *)strSrc; \ break; \ case 12: \ *(SecStrBuf12 *)(void *)strDest = *(const SecStrBuf12 *)(const void *)strSrc; \ break; \ case 13: \ *(SecStrBuf13 *)(void *)strDest = *(const SecStrBuf13 *)(const void *)strSrc; \ break; \ case 14: \ *(SecStrBuf14 *)(void *)strDest = *(const SecStrBuf14 *)(const void *)strSrc; \ break; \ case 15: \ *(SecStrBuf15 *)(void *)strDest = *(const SecStrBuf15 *)(const void *)strSrc; \ break; \ case 16: \ *(SecStrBuf16 *)(void *)strDest = *(const SecStrBuf16 *)(const void *)strSrc; \ break; \ case 17: \ *(SecStrBuf17 *)(void *)strDest = *(const SecStrBuf17 *)(const void *)strSrc; \ break; \ case 18: \ *(SecStrBuf18 *)(void *)strDest = *(const SecStrBuf18 *)(const void *)strSrc; \ break; \ case 19: \ *(SecStrBuf19 *)(void *)strDest = *(const SecStrBuf19 *)(const void *)strSrc; \ break; \ case 20: \ *(SecStrBuf20 *)(void *)strDest = *(const SecStrBuf20 *)(const void *)strSrc; \ break; \ case 21: \ *(SecStrBuf21 *)(void *)strDest = *(const SecStrBuf21 *)(const void *)strSrc; \ break; \ case 22: \ *(SecStrBuf22 *)(void *)strDest = *(const SecStrBuf22 *)(const void *)strSrc; \ break; \ case 23: \ *(SecStrBuf23 *)(void *)strDest = *(const SecStrBuf23 *)(const void *)strSrc; \ break; \ case 24: \ *(SecStrBuf24 *)(void *)strDest = *(const SecStrBuf24 *)(const void *)strSrc; \ break; \ case 25: \ *(SecStrBuf25 *)(void *)strDest = *(const SecStrBuf25 *)(const void *)strSrc; \ break; \ case 26: \ *(SecStrBuf26 *)(void *)strDest = *(const SecStrBuf26 *)(const void *)strSrc; \ break; \ case 27: \ *(SecStrBuf27 *)(void *)strDest = *(const SecStrBuf27 *)(const void *)strSrc; \ break; \ case 28: \ *(SecStrBuf28 *)(void *)strDest = *(const SecStrBuf28 *)(const void *)strSrc; \ break; \ case 29: \ *(SecStrBuf29 *)(void *)strDest = 
*(const SecStrBuf29 *)(const void *)strSrc; \ break; \ case 30: \ *(SecStrBuf30 *)(void *)strDest = *(const SecStrBuf30 *)(const void *)strSrc; \ break; \ case 31: \ *(SecStrBuf31 *)(void *)strDest = *(const SecStrBuf31 *)(const void *)strSrc; \ break; \ case 32: \ *(SecStrBuf32 *)(void *)strDest = *(const SecStrBuf32 *)(const void *)strSrc; \ break; \ default: \ break; \ } /* END switch */ \ } else { \ char *tmpStrDest = (char *)strDest; \ const char *tmpStrSrc = (const char *)strSrc; \ switch (srcStrLen) { \ case 32: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 31: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 30: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 29: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 28: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 27: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 26: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 25: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 24: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 23: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 22: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 21: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 20: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 19: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 18: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 17: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 16: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 15: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 14: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 13: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 12: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 11: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 10: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 9: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 8: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 7: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 6: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 5: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 4: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 3: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 2: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ case 1: \ *(tmpStrDest++) = *(tmpStrSrc++); \ /* fall-through */ /* FALLTHRU */ \ default: \ break; \ } \ } \ } SECUREC_WHILE_ZERO #endif /* * Check Src Range */ static errno_t CheckSrcRange(char *strDest, size_t destMax, const char *strSrc) { size_t tmpDestMax = destMax; const char *tmpSrc = strSrc; /* use destMax as boundary checker and destMax must be greater than zero */ while (*(tmpSrc) != '\0' && tmpDestMax > 0) { ++tmpSrc; --tmpDestMax; } if (tmpDestMax == 0) { 
strDest[0] = '\0'; SECUREC_ERROR_INVALID_RANGE("strcpy_s"); return ERANGE_AND_RESET; } return EOK; } /* * Handling errors */ errno_t strcpy_error(char *strDest, size_t destMax, const char *strSrc) { if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { SECUREC_ERROR_INVALID_RANGE("strcpy_s"); return ERANGE; } else if (strDest == NULL || strSrc == NULL) { SECUREC_ERROR_INVALID_PARAMTER("strcpy_s"); if (strDest != NULL) { strDest[0] = '\0'; return EINVAL_AND_RESET; } return EINVAL; } return CheckSrcRange(strDest, destMax, strSrc); } /* * Performance optimization. srcStrLen includes '\0' */ static void SecDoStrcpyOpt(char *strDest, const char *strSrc, size_t srcStrLen) { #if SECUREC_IN_KERNEL SecDoMemcpy(strDest, strSrc, srcStrLen); #else if (srcStrLen > SECUREC_STRCOPY_THRESHOLD_SIZE) { SecDoMemcpy(strDest, strSrc, srcStrLen); } else { SECUREC_SMALL_STR_COPY; } #endif } /* * * The strcpy_s function copies the string pointed to by strSrc * (including the terminating null character) into the array pointed to by strDest * The destination string must be large enough to hold the source string, * including the terminating null character. strcpy_s will return EOVERLAP_AND_RESET * if the source and destination strings overlap. * * * strDest Location of destination string buffer * destMax Size of the destination string buffer. * strSrc Null-terminated source string buffer. * * * strDest is updated. * * * EOK Success * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN * EINVAL_AND_RESET strDest != NULL and strSrc is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN * ERANGE destMax is 0 or destMax > SECUREC_STRING_MAX_LEN * ERANGE_AND_RESET strDest does not have enough space and all other parameters are valid and not overlap * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ errno_t strcpy_s(char *strDest, size_t destMax, const char *strSrc) { if ((destMax > 0 && destMax <= SECUREC_STRING_MAX_LEN && strDest != NULL && strSrc != NULL && strDest != strSrc)) { size_t srcStrLen = SecStrMinLen(strSrc, destMax) + 1; /* len includes \0 */ if (srcStrLen <= destMax) { /* use mem overlap check, including \0 */ if (SECUREC_MEMORY_NO_OVERLAP(strDest, strSrc, srcStrLen)) { /* performance optimization, srcStrLen includes '\0' */ SecDoStrcpyOpt(strDest, strSrc, srcStrLen); return EOK; } else { strDest[0] = '\0'; SECUREC_ERROR_BUFFER_OVERLAP("strcpy_s"); return EOVERLAP_AND_RESET; } } } return strcpy_error(strDest, destMax, strSrc); } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(strcpy_s); #endif ================================================ FILE: third_party/securec/src/strncat_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
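 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch for strcpy_s from the preceding file:
 *      char dst[8];
 *      errno_t rc = strcpy_s(dst, sizeof(dst), "abc");     // EOK, dst holds "abc"
 *      rc = strcpy_s(dst, sizeof(dst), "longer than 7");   // ERANGE_AND_RESET, dst[0] == '\0'
 *  ]
 *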
*/ #define SECUREC_INLINE_STR_LEN 1 #define SECUREC_INLINE_DO_MEMCPY 1 #include "securecutil.h" /* * Before this function is called, the basic parameter checking has been done */ static errno_t SecDoStrncat(char *strDest, size_t destMax, const char *strSrc, size_t count) { size_t destLen = SecStrMinLen(strDest, destMax); /* The strSrc is no longer optimized. The reason is that when count is small, * the efficiency of strnlen is higher than that of a hand-written loop. */ size_t srcLen = SecStrMinLen(strSrc, count); if (SECUREC_CAT_STRING_IS_OVERLAP(strDest, destLen, strSrc, srcLen)) { strDest[0] = '\0'; if (strDest + destLen <= strSrc && destLen == destMax) { SECUREC_ERROR_INVALID_PARAMTER("strncat_s"); return EINVAL_AND_RESET; } SECUREC_ERROR_BUFFER_OVERLAP("strncat_s"); return EOVERLAP_AND_RESET; } if (srcLen + destLen >= destMax || strDest == strSrc) { strDest[0] = '\0'; if (destLen == destMax) { SECUREC_ERROR_INVALID_PARAMTER("strncat_s"); return EINVAL_AND_RESET; } SECUREC_ERROR_INVALID_RANGE("strncat_s"); return ERANGE_AND_RESET; } SecDoMemcpy(strDest + destLen, strSrc, srcLen); /* no terminator */ *(strDest + destLen + srcLen) = '\0'; return EOK; } /* * * The strncat_s function appends not more than n successive characters * (not including the terminating null character) * from the array pointed to by strSrc to the end of the string pointed to by strDest * The strncat_s function tries to append the first D characters of strSrc to * the end of strDest, where D is the lesser of count and the length of strSrc. * If appending those D characters will fit within strDest (whose size is given * as destMax) and still leave room for a null terminator, then those characters * are appended, starting at the original terminating null of strDest, and a * new terminating null is appended; otherwise, strDest[0] is set to the null * character. * * * strDest Null-terminated destination string. * destMax Size of the destination buffer. * strSrc Null-terminated source string. * count Number of characters to append, or at which to truncate.
* * * strDest is updated * * * EOK Success * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN * EINVAL_AND_RESET (strDest unterminated and all other parameters are valid) or * (strDest != NULL and strSrc is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN) * ERANGE destMax is 0 or destMax > SECUREC_STRING_MAX_LEN * ERANGE_AND_RESET strDest does not have enough space and all other parameters are valid and not overlap * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ errno_t strncat_s(char *strDest, size_t destMax, const char *strSrc, size_t count) { if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { SECUREC_ERROR_INVALID_RANGE("strncat_s"); return ERANGE; } if (strDest == NULL || strSrc == NULL) { SECUREC_ERROR_INVALID_PARAMTER("strncat_s"); if (strDest != NULL) { strDest[0] = '\0'; return EINVAL_AND_RESET; } return EINVAL; } if (count > SECUREC_STRING_MAX_LEN) { #ifdef SECUREC_COMPATIBLE_WIN_FORMAT if (count == (size_t)(-1)) { /* Windows internal functions may pass in -1 when calling this function */ return SecDoStrncat(strDest, destMax, strSrc, destMax); } #endif strDest[0] = '\0'; SECUREC_ERROR_INVALID_RANGE("strncat_s"); return ERANGE_AND_RESET; } return SecDoStrncat(strDest, destMax, strSrc, count); } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(strncat_s); #endif ================================================ FILE: third_party/securec/src/strncpy_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
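 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch for strncat_s from the preceding file:
 *      char buf[8] = "ab";
 *      errno_t rc = strncat_s(buf, sizeof(buf), "cdefgh", 3);
 *      // rc == EOK and buf holds "abcde": at most count characters are
 *      // appended and a new terminating '\0' is written
 *  ]
 *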
*/ #define SECUREC_INLINE_STR_LEN 1 #define SECUREC_INLINE_DO_MEMCPY 1 #include "securecutil.h" #if defined(SECUREC_COMPATIBLE_WIN_FORMAT) #define SECUREC_STRNCPY_PARAM_OK(strDest, destMax, strSrc, count) \ (((destMax) > 0 && (destMax) <= SECUREC_STRING_MAX_LEN && (strDest) != NULL && (strSrc) != NULL && \ ((count) <= SECUREC_STRING_MAX_LEN || (count) == ((size_t)(-1))) && (count) > 0)) #else #define SECUREC_STRNCPY_PARAM_OK(strDest, destMax, strSrc, count) \ (((destMax) > 0 && (destMax) <= SECUREC_STRING_MAX_LEN && (strDest) != NULL && (strSrc) != NULL && \ (count) <= SECUREC_STRING_MAX_LEN && (count) > 0)) #endif /* * Check Src Count Range */ static errno_t CheckSrcCountRange(char *strDest, size_t destMax, const char *strSrc, size_t count) { size_t tmpDestMax = destMax; size_t tmpCount = count; const char *endPos = strSrc; /* use destMax and count as boundary checker and destMax must be greater than zero */ while (*(endPos) != '\0' && tmpDestMax > 0 && tmpCount > 0) { ++endPos; --tmpCount; --tmpDestMax; } if (tmpDestMax == 0) { strDest[0] = '\0'; SECUREC_ERROR_INVALID_RANGE("strncpy_s"); return ERANGE_AND_RESET; } return EOK; } /* * Handling errors; when dest equals src, return EOK */ errno_t strncpy_error(char *strDest, size_t destMax, const char *strSrc, size_t count) { if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { SECUREC_ERROR_INVALID_RANGE("strncpy_s"); return ERANGE; } else if (strDest == NULL || strSrc == NULL) { SECUREC_ERROR_INVALID_PARAMTER("strncpy_s"); if (strDest != NULL) { strDest[0] = '\0'; return EINVAL_AND_RESET; } return EINVAL; } else if (count > SECUREC_STRING_MAX_LEN) { strDest[0] = '\0'; /* clear dest string */ SECUREC_ERROR_INVALID_RANGE("strncpy_s"); return ERANGE_AND_RESET; } else if (count == 0) { strDest[0] = '\0'; return EOK; } return CheckSrcCountRange(strDest, destMax, strSrc, count); } /* * * The strncpy_s function copies not more than n successive characters (not including the terminating null character) * from the array pointed to by strSrc to the array pointed to by strDest. * * * strDest Destination string. * destMax The size of the destination string, in characters. * strSrc Source string. * count Number of characters to be copied.
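 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch, assuming a caller-owned buffer:
 *      char buf[8];
 *      errno_t rc = strncpy_s(buf, sizeof(buf), "abcdef", 3);
 *      // rc == EOK and buf holds "abc"; the copy is always '\0'-terminated
 *  ]
 *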
* * * strDest is updated * * * EOK Success * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN * EINVAL_AND_RESET strDest != NULL and strSrc is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN * ERANGE destMax is 0 or destMax > SECUREC_STRING_MAX_LEN * ERANGE_AND_RESET strDest does not have enough space and all other parameters are valid and not overlap * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ errno_t strncpy_s(char *strDest, size_t destMax, const char *strSrc, size_t count) { if (SECUREC_STRNCPY_PARAM_OK(strDest, destMax, strSrc, count)) { size_t minCpLen; /* use it to store the maximum length limit */ if (count < destMax) { minCpLen = SecStrMinLen(strSrc, count); /* no ending terminator */ } else { size_t tmpCount = destMax; #ifdef SECUREC_COMPATIBLE_WIN_FORMAT if (count == ((size_t)(-1))) { tmpCount = destMax - 1; } #endif minCpLen = SecStrMinLen(strSrc, tmpCount); if (minCpLen == destMax) { strDest[0] = '\0'; SECUREC_ERROR_INVALID_RANGE("strncpy_s"); return ERANGE_AND_RESET; } } if (SECUREC_STRING_NO_OVERLAP(strDest, strSrc, minCpLen) || strDest == strSrc) { /* Not overlap */ SecDoMemcpy(strDest, strSrc, minCpLen); /* copy string without terminator */ strDest[minCpLen] = '\0'; return EOK; } else { strDest[0] = '\0'; SECUREC_ERROR_BUFFER_OVERLAP("strncpy_s"); return EOVERLAP_AND_RESET; } } return strncpy_error(strDest, destMax, strSrc, count); } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(strncpy_s); #endif ================================================ FILE: third_party/securec/src/strtok_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securec.h" /* * Find beginning of token (skip over leading delimiters). Note that * there is no token if this loop sets string to point to the terminal null. */ static char *SecFindBegin(char *strToken, const char *strDelimit) { char *token = strToken; while (*token != '\0') { const char *ctl = strDelimit; while (*ctl != '\0' && *ctl != *token) { ++ctl; } if (*ctl == '\0') { /* didn't find any delimiter at the start of the string, break the loop */ break; } ++token; } return token; } /* * Find rest of token */ static char *SecFindRest(char *strToken, const char *strDelimit) { /* Find the rest of the token. If it is not the end of the string, * put a null there.
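 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch for the public strtok_s defined below, assuming a writable string:
 *      char line[] = "a,b;c";
 *      char *context = NULL;
 *      char *tok = strtok_s(line, ",;", &context);                    // "a"
 *      while (tok != NULL) { tok = strtok_s(NULL, ",;", &context); }  // "b", "c", then NULL
 *  ]
 *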
*/ char *token = strToken; while (*token != '\0') { const char *ctl = strDelimit; while (*ctl != '\0' && *ctl != *token) { ++ctl; } if (*ctl != '\0') { /* found a delimiter */ *token++ = '\0'; /* set string terminator */ break; } ++token; } return token; } /* * Find the final position pointer */ static char *SecUpdateToken(char *strToken, const char *strDelimit, char **context) { /* point to updated position */ char *token = SecFindRest(strToken, strDelimit); /* record string position for next search in the context */ *context = token; /* Determine if a token has been found. */ if (token == strToken) { return NULL; } return strToken; } /* * * The strtok_s function parses a string into a sequence of tokens, * replacing each character in strToken that matches a character in the strDelimit set with '\0'. * On the first call to strtok_s the string to be parsed should be specified in strToken. * In each subsequent call that should parse the same string, strToken should be NULL * * strToken String containing token or tokens. * strDelimit Set of delimiter characters. * context Used to store position information between calls * to strtok_s * * context is updated * * On the first call, returns the address of the first token, or NULL if no token is found. * In subsequent calls, strToken is set to NULL and the same context is passed as in the previous call; * returns NULL if no further token remains, otherwise returns the next token. */ char *strtok_s(char *strToken, const char *strDelimit, char **context) { char *orgToken = strToken; /* validate delimiter and string context */ if (context == NULL || strDelimit == NULL) { return NULL; } /* valid input string and string pointer from where to search */ if (orgToken == NULL && (*context) == NULL) { return NULL; } /* If string is null, continue searching from previous string position stored in context */ if (orgToken == NULL) { orgToken = *context; } orgToken = SecFindBegin(orgToken, strDelimit); return SecUpdateToken(orgToken, strDelimit, context); } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(strtok_s); #endif ================================================ FILE: third_party/securec/src/swprintf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securec.h" /* * * The swprintf_s function is the wide-character equivalent of the sprintf_s function * * * strDest Storage location for the output. * destMax Maximum number of characters to store. * format Format-control string. * ... Optional arguments * * * strDest is updated * * * return the number of wide characters stored in strDest, not counting the terminating null wide character. * return -1 if an error occurred. * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ int swprintf_s(wchar_t *strDest, size_t destMax, const wchar_t *format, ...)
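/* [Editorial note, not part of the upstream securec source: a minimal usage
 * sketch, assuming a caller-owned wide buffer; destMax counts wide characters:
 *     wchar_t wbuf[32];
 *     int n = swprintf_s(wbuf, sizeof(wbuf) / sizeof(wbuf[0]), L"%d-%d", 4, 2);
 *     // n == 3 and wbuf holds L"4-2"; n == -1 on error
 * ]
 */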
{ int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); ret = vswprintf_s(strDest, destMax, format, argList); va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } ================================================ FILE: third_party/securec/src/swscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "securec.h" /* * * The swscanf_s function is the wide-character equivalent of the sscanf_s function * The swscanf_s function reads data from buffer into the location given by * each argument. Every argument must be a pointer to a variable with a type * that corresponds to a type specifier in format. The format argument controls * the interpretation of the input fields and has the same form and function * as the format argument for the scanf function. If copying takes place between * strings that overlap, the behavior is undefined. * * * buffer Stored data. * format Format control string, see Format Specifications. * ... Optional arguments. * * * ... the converted value stored in user assigned address * * * Each of these functions returns the number of fields successfully converted * and assigned; The return value does not include fields that were read but not * assigned. * A return value of 0 indicates that no fields were assigned. * return -1 if an error occurs. */ int swscanf_s(const wchar_t *buffer, const wchar_t *format, ...) { int ret; /* If initialization causes e838 */ va_list argList; va_start(argList, format); ret = vswscanf_s(buffer, format, argList); va_end(argList); (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ return ret; } ================================================ FILE: third_party/securec/src/vfscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "secinput.h" /* * * The vfscanf_s function is equivalent to fscanf_s, with the variable argument list replaced by argList * The vfscanf_s function reads data from the current position of stream into * the locations given by argument (if any). Each argument must be a pointer * to a variable of a type that corresponds to a type specifier in format. 
* format controls the interpretation of the input fields and has the same * form and function as the format argument for scanf. * * * stream Pointer to FILE structure. * format Format control string, see Format Specifications. * argList pointer to list of arguments * * * argList the converted value stored in user assigned address * * * Each of these functions returns the number of fields successfully converted * and assigned; the return value does not include fields that were read but * not assigned. A return value of 0 indicates that no fields were assigned. * return -1 if an error occurs. */ int vfscanf_s(FILE *stream, const char *format, va_list argList) { int retVal; /* If initialization causes e838 */ SecFileStream fStr; if ((stream == NULL) || (format == NULL)) { SECUREC_ERROR_INVALID_PARAMTER("vfscanf_s"); return SECUREC_SCANF_EINVAL; } if (stream == stdin) { return vscanf_s(format, argList); } SECUREC_LOCK_FILE(stream); SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_FILE_STREAM_FLAG, stream, SECUREC_UNINITIALIZED_FILE_POS, NULL, 0); retVal = SecInputS(&fStr, format, argList); SECUREC_UNLOCK_FILE(stream); if (retVal < 0) { SECUREC_ERROR_INVALID_PARAMTER("vfscanf_s"); return SECUREC_SCANF_EINVAL; } return retVal; } ================================================ FILE: third_party/securec/src/vfwscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "secinput.h" /* * * The vfwscanf_s function is the wide-character equivalent of the vfscanf_s function * The vfwscanf_s function reads data from the current position of stream into * the locations given by argument (if any). Each argument must be a pointer * to a variable of a type that corresponds to a type specifier in format. * format controls the interpretation of the input fields and has the same form * and function as the format argument for scanf. * * * stream Pointer to FILE structure. * format Format control string, see Format Specifications. * argList pointer to list of arguments * * * argList the converted value stored in user assigned address * * * Each of these functions returns the number of fields successfully converted * and assigned; the return value does not include fields that were read but * not assigned. A return value of 0 indicates that no fields were assigned. * return -1 if an error occurs. 
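 *
 * [Editorial note, not part of the upstream securec source: the v*scanf_s
 *  functions exist to build custom variadic wrappers; a minimal sketch,
 *  where my_fscanf is a hypothetical helper name:
 *      int my_fscanf(FILE *fp, const char *fmt, ...) {
 *          va_list ap;
 *          int n;
 *          va_start(ap, fmt);
 *          n = vfscanf_s(fp, fmt, ap);   // vfwscanf_s is the wide-format analogue
 *          va_end(ap);
 *          return n;
 *      }
 *  vscanf_s and vwscanf_s work analogously for stdin.
 *  ]
 *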
*/ int vfwscanf_s(FILE *stream, const wchar_t *format, va_list argList) { int retVal; /* If initialization causes e838 */ SecFileStream fStr; if ((stream == NULL) || (format == NULL)) { SECUREC_ERROR_INVALID_PARAMTER("vfwscanf_s"); return SECUREC_SCANF_EINVAL; } if (stream == stdin) { return vwscanf_s(format, argList); } SECUREC_LOCK_FILE(stream); SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_FILE_STREAM_FLAG, stream, SECUREC_UNINITIALIZED_FILE_POS, NULL, 0); retVal = SecInputSW(&fStr, format, argList); SECUREC_UNLOCK_FILE(stream); if (retVal < 0) { SECUREC_ERROR_INVALID_PARAMTER("vfwscanf_s"); return SECUREC_SCANF_EINVAL; } return retVal; } ================================================ FILE: third_party/securec/src/vscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "secinput.h" /* * * The vscanf_s function is equivalent to scanf_s, with the variable argument list replaced by argList, * The vscanf_s function reads data from the standard input stream stdin and * writes the data into the location that's given by argument. Each argument * must be a pointer to a variable of a type that corresponds to a type specifier * in format. If copying occurs between strings that overlap, the behavior is * undefined. * * * format Format control string. * argList pointer to list of arguments * * * argList the converted value stored in user assigned address * * * Returns the number of fields successfully converted and assigned; * the return value does not include fields that were read but not assigned. * A return value of 0 indicates that no fields were assigned. * return -1 if an error occurs. */ int vscanf_s(const char *format, va_list argList) { int retVal; /* If initialization causes e838 */ SecFileStream fStr; SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_FROM_STDIN_FLAG, stdin, 0, NULL, 0); /* * "va_list" has different definition on different platform, so we can't use argList == NULL * to determine it's invalid. If you has fixed platform, you can check some fields to validate it, * such as "argList == NULL" or argList.xxx != NULL or *(size_t *)&argList != 0. */ if (format == NULL || fStr.pf == NULL) { SECUREC_ERROR_INVALID_PARAMTER("vscanf_s"); return SECUREC_SCANF_EINVAL; } SECUREC_LOCK_STDIN(0, fStr.pf); retVal = SecInputS(&fStr, format, argList); SECUREC_UNLOCK_STDIN(0, fStr.pf); if (retVal < 0) { SECUREC_ERROR_INVALID_PARAMTER("vscanf_s"); return SECUREC_SCANF_EINVAL; } return retVal; } ================================================ FILE: third_party/securec/src/vsnprintf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
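 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch for vsnprintf_s, defined below, as the core of a logging helper;
 *  log_line is a hypothetical wrapper name:
 *      void log_line(char *out, size_t outMax, const char *fmt, ...) {
 *          va_list ap;
 *          va_start(ap, fmt);
 *          (void)vsnprintf_s(out, outMax, outMax - 1, fmt, ap);
 *          va_end(ap);
 *      }
 *  ]
 *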
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "secureprintoutput.h" #if SECUREC_ENABLE_VSNPRINTF /* * * The vsnprintf_s function is equivalent to the vsnprintf function * except for the parameter destMax/count and the explicit runtime-constraints violation * The vsnprintf_s function takes a pointer to an argument list, then formats * and writes up to count characters of the given data to the memory pointed * to by strDest and appends a terminating null. * * * strDest Storage location for the output. * destMax The size of the strDest for output. * count Maximum number of characters to write (not including * the terminating NULL) * format Format-control string. * argList pointer to list of arguments. * * * strDest is updated * * * return the number of characters written, not including the terminating null * return -1 if an error occurs. * return -1 if count < destMax and the output string has been truncated * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ int vsnprintf_s(char *strDest, size_t destMax, size_t count, const char *format, va_list argList) { int retVal; if (format == NULL || strDest == NULL || destMax == 0 || destMax > SECUREC_STRING_MAX_LEN || (count > (SECUREC_STRING_MAX_LEN - 1) && count != (size_t)(-1))) { if (strDest != NULL && destMax > 0 && destMax <= SECUREC_STRING_MAX_LEN) { strDest[0] = '\0'; } SECUREC_ERROR_INVALID_PARAMTER("vsnprintf_s"); return -1; } if (destMax > count) { retVal = SecVsnprintfImpl(strDest, count + 1, format, argList); if (retVal == SECUREC_PRINTF_TRUNCATE) { /* lsd add to keep dest buffer not destroyed 2014.2.18 */ /* the string has been truncated, return -1 */ return -1; /* to skip error handler, return strlen(strDest) or -1 */ } } else { retVal = SecVsnprintfImpl(strDest, destMax, format, argList); #ifdef SECUREC_COMPATIBLE_WIN_FORMAT if (retVal == SECUREC_PRINTF_TRUNCATE && count == (size_t)(-1)) { return -1; } #endif } if (retVal < 0) { strDest[0] = '\0'; /* empty the dest buffer strDest */ if (retVal == SECUREC_PRINTF_TRUNCATE) { /* Buffer too small */ SECUREC_ERROR_INVALID_RANGE("vsnprintf_s"); } SECUREC_ERROR_INVALID_PARAMTER("vsnprintf_s"); return -1; } return retVal; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(vsnprintf_s); #endif #endif #if SECUREC_SNPRINTF_TRUNCATED /* * * The vsnprintf_truncated_s function is equivalent to the vsnprintf function * except for the parameter destMax/count and the explicit runtime-constraints violation * The vsnprintf_truncated_s function takes a pointer to an argument list, then formats * and writes the given data to the memory pointed * to by strDest and appends a terminating null. * * * strDest Storage location for the output. * destMax The size of the strDest for output. * format Format-control string. * argList pointer to list of arguments. * * * strDest is updated * * * return the number of characters written, not including the terminating null * return -1 if an error occurs.
* return destMax-1 if output string has been truncated * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ int vsnprintf_truncated_s(char *strDest, size_t destMax, const char *format, va_list argList) { int retVal; if (format == NULL || strDest == NULL || destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { if (strDest != NULL && destMax > 0 && destMax <= SECUREC_STRING_MAX_LEN) { strDest[0] = '\0'; } SECUREC_ERROR_INVALID_PARAMTER("vsnprintf_truncated_s"); return -1; } retVal = SecVsnprintfImpl(strDest, destMax, format, argList); if (retVal < 0) { if (retVal == SECUREC_PRINTF_TRUNCATE) { return (int)(destMax - 1); /* to skip error handler, return strlen(strDest) */ } strDest[0] = '\0'; /* empty the dest strDest */ SECUREC_ERROR_INVALID_PARAMTER("vsnprintf_truncated_s"); return -1; } return retVal; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(vsnprintf_truncated_s); #endif #endif ================================================ FILE: third_party/securec/src/vsprintf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "secureprintoutput.h" /* * * The vsprintf_s function is equivalent to the vsprintf function * except for the parameter destMax and the explicit runtime-constraints violation * The vsprintf_s function takes a pointer to an argument list, and then formats * and writes the given data to the memory pointed to by strDest. * The function differ from the non-secure versions only in that the secure * versions support positional parameters. * * * strDest Storage location for the output. * destMax Size of strDest * format Format specification. * argList pointer to list of arguments * * * strDest is updated * * * return the number of characters written, not including the terminating null character, * return -1 if an error occurs. 
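 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch; unlike vsnprintf_s there is no separate count argument, so destMax
 *  alone bounds the output (my_sprintf is a hypothetical wrapper name):
 *      int my_sprintf(char *dst, size_t dstMax, const char *fmt, ...) {
 *          va_list ap;
 *          int n;
 *          va_start(ap, fmt);
 *          n = vsprintf_s(dst, dstMax, fmt, ap);
 *          va_end(ap);
 *          return n;   // -1 on error or when dst is too small
 *      }
 *  ]
 *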
* * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ int vsprintf_s(char *strDest, size_t destMax, const char *format, va_list argList) { int retVal; /* If initialization causes e838 */ if (format == NULL || strDest == NULL || destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { if (strDest != NULL && destMax > 0 && destMax <= SECUREC_STRING_MAX_LEN) { strDest[0] = '\0'; } SECUREC_ERROR_INVALID_PARAMTER("vsprintf_s"); return -1; } retVal = SecVsnprintfImpl(strDest, destMax, format, argList); if (retVal < 0) { strDest[0] = '\0'; if (retVal == SECUREC_PRINTF_TRUNCATE) { /* Buffer is too small */ SECUREC_ERROR_INVALID_RANGE("vsprintf_s"); } SECUREC_ERROR_INVALID_PARAMTER("vsprintf_s"); return -1; } return retVal; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(vsprintf_s); #endif ================================================ FILE: third_party/securec/src/vsscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "secinput.h" #if defined(SECUREC_VXWORKS_PLATFORM) && (!defined(SECUREC_SYSAPI4VXWORKS) && !defined(SECUREC_CTYPE_MACRO_ADAPT)) #include #endif /* * * vsscanf_s * * * * The vsscanf_s function is equivalent to sscanf_s, with the variable argument list replaced by argList * The vsscanf_s function reads data from buffer into the location given by * each argument. Every argument must be a pointer to a variable with a type * that corresponds to a type specifier in format. The format argument controls * the interpretation of the input fields and has the same form and function * as the format argument for the scanf function. * If copying takes place between strings that overlap, the behavior is undefined. * * * buffer Stored data * format Format control string, see Format Specifications. * argList pointer to list of arguments * * * argList the converted value stored in user assigned address * * * Each of these functions returns the number of fields successfully converted * and assigned; the return value does not include fields that were read but * not assigned. A return value of 0 indicates that no fields were assigned. * return -1 if an error occurs. 
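 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch for a custom sscanf-style wrapper (my_sscanf is a hypothetical name):
 *      int my_sscanf(const char *buf, const char *fmt, ...) {
 *          va_list ap;
 *          int n;
 *          va_start(ap, fmt);
 *          n = vsscanf_s(buf, fmt, ap);
 *          va_end(ap);
 *          return n;   // number of fields assigned, or -1 on error
 *      }
 *  ]
 *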
*/ int vsscanf_s(const char *buffer, const char *format, va_list argList) { size_t count; /* If initialization causes e838 */ int retVal; SecFileStream fStr; /* validation section */ if (buffer == NULL || format == NULL) { SECUREC_ERROR_INVALID_PARAMTER("vsscanf_s"); return SECUREC_SCANF_EINVAL; } count = strlen(buffer); if (count == 0 || count > SECUREC_STRING_MAX_LEN) { SecClearDestBuf(buffer, format, argList); SECUREC_ERROR_INVALID_PARAMTER("vsscanf_s"); return SECUREC_SCANF_EINVAL; } #ifdef SECUREC_VXWORKS_PLATFORM /* * On the VxWorks platform, when buffer is a whitespace-only string, the first %s argument is set to zero, as in the following usage: * " \v\f\t\r\n", "%s", str, strSize * Do not check every character; if the first and last characters are whitespace, treat the buffer as a whitespace-only string */ if (isspace((int)buffer[0]) && isspace((int)buffer[count - 1])) { SecClearDestBuf(buffer, format, argList); } #endif SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_MEM_STR_FLAG, NULL, 0, buffer, (int)count); retVal = SecInputS(&fStr, format, argList); if (retVal < 0) { SECUREC_ERROR_INVALID_PARAMTER("vsscanf_s"); return SECUREC_SCANF_EINVAL; } return retVal; } #if SECUREC_IN_KERNEL EXPORT_SYMBOL(vsscanf_s); #endif ================================================ FILE: third_party/securec/src/vswprintf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "secureprintoutput.h" /* * * The vswprintf_s function is the wide-character equivalent of the vsprintf_s function * * * strDest Storage location for the output. * destMax Size of strDest * format Format specification. * argList pointer to list of arguments * * * strDest is updated * * * return the number of wide characters stored in strDest, not counting the terminating null wide character. * return -1 if an error occurred. * * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid */ int vswprintf_s(wchar_t *strDest, size_t destMax, const wchar_t *format, va_list argList) { int retVal; /* If initialization causes e838 */ if (format == NULL || strDest == NULL || destMax == 0 || destMax > (SECUREC_WCHAR_STRING_MAX_LEN)) { if (strDest != NULL && destMax > 0) { strDest[0] = '\0'; } SECUREC_ERROR_INVALID_PARAMTER("vswprintf_s"); return -1; } retVal = SecVswprintfImpl(strDest, destMax, format, argList); if (retVal < 0) { strDest[0] = '\0'; if (retVal == SECUREC_PRINTF_TRUNCATE) { /* Buffer too small */ SECUREC_ERROR_INVALID_RANGE("vswprintf_s"); } SECUREC_ERROR_INVALID_PARAMTER("vswprintf_s"); return -1; } return retVal; } ================================================ FILE: third_party/securec/src/vswscanf_s.c ================================================ /** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
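 *
 * [Editorial note, not part of the upstream securec source: a minimal usage
 *  sketch for vswprintf_s from the preceding file; destMax counts wide
 *  characters, not bytes:
 *      wchar_t wbuf[64];
 *      // inside a variadic wrapper, after va_start(ap, fmt):
 *      //     int n = vswprintf_s(wbuf, sizeof(wbuf) / sizeof(wbuf[0]), fmt, ap);
 *  ]
 *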
================================================
FILE: third_party/securec/src/vswscanf_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "secinput.h"

static size_t SecWcslen(const wchar_t *s)
{
    const wchar_t *end = s;
    while (*end != L'\0') {
        ++end;
    }
    return ((size_t)((end - s)));
}

/*
 * The vswscanf_s function is the wide-character equivalent of the vsscanf_s function.
 * It reads data from buffer into the location given by each argument.
 * Every argument must be a pointer to a variable with a type
 * that corresponds to a type specifier in format.
 * The format argument controls the interpretation of the input fields and
 * has the same form and function as the format argument for the scanf function.
 * If copying takes place between strings that overlap, the behavior is undefined.
 *
 * buffer                Stored data
 * format                Format control string, see Format Specifications.
 * argList               pointer to list of arguments
 *
 * argList               the converted value stored in user assigned address
 *
 * Each of these functions returns the number of fields successfully converted
 * and assigned; the return value does not include fields that were read but
 * not assigned. A return value of 0 indicates that no fields were assigned.
 * return -1 if an error occurs.
 */
int vswscanf_s(const wchar_t *buffer, const wchar_t *format, va_list argList)
{
    size_t count; /* If initialization causes e838 */
    SecFileStream fStr;
    int retVal;

    /* validation section */
    if (buffer == NULL || format == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("vswscanf_s");
        return SECUREC_SCANF_EINVAL;
    }
    count = SecWcslen(buffer);
    if (count == 0 || count > SECUREC_WCHAR_STRING_MAX_LEN) {
        SecClearDestBufW(buffer, format, argList);
        SECUREC_ERROR_INVALID_PARAMTER("vswscanf_s");
        return SECUREC_SCANF_EINVAL;
    }
    SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_MEM_STR_FLAG, NULL, 0,
        (const char *)buffer, (int)count * ((int)sizeof(wchar_t)));
    retVal = SecInputSW(&fStr, format, argList);
    if (retVal < 0) {
        SECUREC_ERROR_INVALID_PARAMTER("vswscanf_s");
        return SECUREC_SCANF_EINVAL;
    }
    return retVal;
}
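The same wrapper pattern applies to the wide-character scanner; the return value counts assigned fields. A minimal sketch with a hypothetical ScanW helper:

#include <stdarg.h>
#include <wchar.h>
#include "securec.h"

/* Hypothetical helper around vswscanf_s (the wide-character sscanf_s). */
static int ScanW(const wchar_t *buffer, const wchar_t *fmt, ...)
{
    va_list args;
    int fields;
    va_start(args, fmt);
    fields = vswscanf_s(buffer, fmt, args);
    va_end(args);
    return fields; /* e.g. ScanW(L"1920 1080", L"%d %d", &w, &h) yields 2 */
}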
================================================
FILE: third_party/securec/src/vwscanf_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "secinput.h"

/*
 * The vwscanf_s function is the wide-character equivalent of the vscanf_s function.
 * It reads data from the standard input stream stdin and writes the
 * data into the location given by each argument. Each argument must be a
 * pointer to a variable of a type that corresponds to a type specifier in
 * format. If copying occurs between strings that overlap, the behavior is
 * undefined.
 *
 * format                Format control string.
 * argList               pointer to list of arguments
 *
 * argList               the converted value stored in user assigned address
 *
 * Returns the number of fields successfully converted and assigned;
 * the return value does not include fields that were read but not assigned.
 * A return value of 0 indicates that no fields were assigned.
 * return -1 if an error occurs.
 */
int vwscanf_s(const wchar_t *format, va_list argList)
{
    int retVal; /* If initialization causes e838 */
    SecFileStream fStr;

    SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_FROM_STDIN_FLAG, stdin, 0, NULL, 0);
    if (format == NULL || fStr.pf == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("vwscanf_s");
        return SECUREC_SCANF_EINVAL;
    }
    SECUREC_LOCK_STDIN(0, fStr.pf);
    retVal = SecInputSW(&fStr, format, argList);
    SECUREC_UNLOCK_STDIN(0, fStr.pf);
    if (retVal < 0) {
        SECUREC_ERROR_INVALID_PARAMTER("vwscanf_s");
        return SECUREC_SCANF_EINVAL;
    }
    return retVal;
}
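wscanf_s, defined at the end of this section, is exactly the variadic wrapper over vwscanf_s; checking the returned field count distinguishes partial input from hard errors. An illustrative sketch, assuming the _s scanf convention that %ls takes a buffer-size argument in wide characters; ReadConfig is hypothetical:

#include <wchar.h>
#include "securec.h"

/* Hypothetical sketch: read an int and a wide string from stdin. */
static int ReadConfig(int *port, wchar_t *name, unsigned int nameCount)
{
    int fields = wscanf_s(L"%d %15ls", port, name, nameCount);
    if (fields < 0) {
        return -1; /* invalid parameter or input failure */
    }
    return fields; /* 0..2: number of fields actually assigned */
}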
================================================
FILE: third_party/securec/src/wcscat_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define SECUREC_INLINE_DO_MEMCPY 1
#include "securecutil.h"

/*
 * Before this function is called, the basic parameter checking has been done
 */
static errno_t SecDoWcscat(wchar_t *strDest, size_t destMax, const wchar_t *strSrc)
{
    size_t destLen;
    size_t srcLen;
    size_t maxCount; /* Store the maximum available count */

    /* To calculate the length of a wide character, the parameter must be a wide character */
    SECUREC_CALC_WSTR_LEN(strDest, destMax, &destLen);
    maxCount = destMax - destLen;
    SECUREC_CALC_WSTR_LEN(strSrc, maxCount, &srcLen);

    if (SECUREC_CAT_STRING_IS_OVERLAP(strDest, destLen, strSrc, srcLen)) {
        strDest[0] = L'\0';
        if (strDest + destLen <= strSrc && destLen == destMax) {
            SECUREC_ERROR_INVALID_PARAMTER("wcscat_s");
            return EINVAL_AND_RESET;
        }
        SECUREC_ERROR_BUFFER_OVERLAP("wcscat_s");
        return EOVERLAP_AND_RESET;
    }
    if (srcLen + destLen >= destMax || strDest == strSrc) {
        strDest[0] = L'\0';
        if (destLen == destMax) {
            SECUREC_ERROR_INVALID_PARAMTER("wcscat_s");
            return EINVAL_AND_RESET;
        }
        SECUREC_ERROR_INVALID_RANGE("wcscat_s");
        return ERANGE_AND_RESET;
    }
    SecDoMemcpy(strDest + destLen, strSrc, (srcLen + 1) * sizeof(wchar_t)); /* single character length include \0 */
    return EOK;
}

/*
 * The wcscat_s function appends a copy of the wide string pointed to by strSrc
 * (including the terminating null wide character)
 * to the end of the wide string pointed to by strDest.
 * The arguments and return value of wcscat_s are wide-character strings.
 *
 * The wcscat_s function appends strSrc to strDest and terminates the resulting
 * string with a null character. The initial character of strSrc overwrites the
 * terminating null character of strDest. wcscat_s will return EOVERLAP_AND_RESET if the
 * source and destination strings overlap.
 *
 * Note that the second parameter is the total size of the buffer, not the
 * remaining size.
 *
 * strDest               Null-terminated destination string buffer.
 * destMax               Size of the destination string buffer.
 * strSrc                Null-terminated source string buffer.
 *
 * strDest is updated
 *
 * EOK                   Success
 * EINVAL                strDest is NULL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN
 * EINVAL_AND_RESET      (strDest unterminated and all other parameters are valid) or
 *                       (strDest != NULL and strSrc is NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_STRING_MAX_LEN)
 * ERANGE                destMax > SECUREC_WCHAR_STRING_MAX_LEN or destMax is 0
 * ERANGE_AND_RESET      strDest does not have enough space and all other parameters are valid and do not overlap
 * EOVERLAP_AND_RESET    dest buffer and source buffer are overlapped and all parameters are valid
 *
 * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid
 */
errno_t wcscat_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc)
{
    if (destMax == 0 || destMax > SECUREC_WCHAR_STRING_MAX_LEN) {
        SECUREC_ERROR_INVALID_RANGE("wcscat_s");
        return ERANGE;
    }
    if (strDest == NULL || strSrc == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("wcscat_s");
        if (strDest != NULL) {
            strDest[0] = L'\0';
            return EINVAL_AND_RESET;
        }
        return EINVAL;
    }
    return SecDoWcscat(strDest, destMax, strSrc);
}
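The destMax pitfall called out above is worth a concrete sketch: it is the total buffer size, not the remaining space. Illustrative only; the CatExample function is not part of the library:

#include <wchar.h>
#include "securec.h"

static void CatExample(void)
{
    wchar_t buf[16] = L"mind";
    /* destMax is the TOTAL size of buf in wide characters, not the free space after L"mind" */
    if (wcscat_s(buf, sizeof(buf) / sizeof(buf[0]), L"spore") != EOK) {
        /* on the *_AND_RESET failures, buf[0] has been set to L'\0' */
        return;
    }
    /* buf now holds L"mindspore" */
}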
================================================
FILE: third_party/securec/src/wcscpy_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define SECUREC_INLINE_DO_MEMCPY 1
#include "securecutil.h"

static errno_t SecDoWcscpy(wchar_t *strDest, size_t destMax, const wchar_t *strSrc)
{
    size_t srcStrLen;

    SECUREC_CALC_WSTR_LEN(strSrc, destMax, &srcStrLen);
    if (srcStrLen == destMax) {
        strDest[0] = '\0';
        SECUREC_ERROR_INVALID_RANGE("wcscpy_s");
        return ERANGE_AND_RESET;
    }
    if (strDest == strSrc) {
        return EOK;
    }
    if (SECUREC_STRING_NO_OVERLAP(strDest, strSrc, srcStrLen)) {
        /* performance optimization; the copied length includes the '\0' */
        SecDoMemcpy(strDest, strSrc, (srcStrLen + 1) * sizeof(wchar_t)); /* single character length include \0 */
        return EOK;
    } else {
        strDest[0] = L'\0';
        SECUREC_ERROR_BUFFER_OVERLAP("wcscpy_s");
        return EOVERLAP_AND_RESET;
    }
}

/*
 * The wcscpy_s function copies the wide string pointed to by strSrc
 * (including the terminating null wide character) into the array pointed to by strDest.
 *
 * strDest               Destination string buffer.
 * destMax               Size of the destination string buffer.
 * strSrc                Null-terminated source string buffer.
 *
 * strDest is updated.
 *
 * EOK                   Success
 * EINVAL                strDest is NULL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN
 * EINVAL_AND_RESET      strDest != NULL and strSrc is NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_STRING_MAX_LEN
 * ERANGE                destMax > SECUREC_WCHAR_STRING_MAX_LEN or destMax is 0
 * ERANGE_AND_RESET      destMax <= length of strSrc and strDest != strSrc
 *                       and strDest != NULL and strSrc != NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_STRING_MAX_LEN and not overlap
 * EOVERLAP_AND_RESET    dest buffer and source buffer are overlapped and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_STRING_MAX_LEN
 *                       and strDest != NULL and strSrc != NULL and strDest != strSrc
 *
 * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid
 */
errno_t wcscpy_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc)
{
    if (destMax == 0 || destMax > SECUREC_WCHAR_STRING_MAX_LEN) {
        SECUREC_ERROR_INVALID_RANGE("wcscpy_s");
        return ERANGE;
    }
    if (strDest == NULL || strSrc == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("wcscpy_s");
        if (strDest != NULL) {
            strDest[0] = L'\0';
            return EINVAL_AND_RESET;
        }
        return EINVAL;
    }
    return SecDoWcscpy(strDest, destMax, strSrc);
}
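A short sketch of the ERANGE_AND_RESET boundary condition: the terminator must fit inside destMax. Illustrative only:

#include <wchar.h>
#include "securec.h"

static void CopyExample(void)
{
    wchar_t dst[8];
    /* L"serving" is 7 characters; with destMax == 8 the terminator just fits.
     * destMax == 7 would return ERANGE_AND_RESET and clear dst[0]. */
    if (wcscpy_s(dst, sizeof(dst) / sizeof(dst[0]), L"serving") == EOK) {
        /* dst holds L"serving" */
    }
}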
================================================
FILE: third_party/securec/src/wcsncat_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define SECUREC_INLINE_DO_MEMCPY 1
#include "securecutil.h"

/*
 * Before this function is called, the basic parameter checking has been done
 */
static errno_t SecDoWcsncat(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count)
{
    size_t destLen;
    size_t srcLen;

    /* To calculate the length of a wide character, the parameter must be a wide character */
    SECUREC_CALC_WSTR_LEN(strDest, destMax, &destLen);
    SECUREC_CALC_WSTR_LEN(strSrc, count, &srcLen);

    if (SECUREC_CAT_STRING_IS_OVERLAP(strDest, destLen, strSrc, srcLen)) {
        strDest[0] = L'\0';
        if (strDest + destLen <= strSrc && destLen == destMax) {
            SECUREC_ERROR_INVALID_PARAMTER("wcsncat_s");
            return EINVAL_AND_RESET;
        }
        SECUREC_ERROR_BUFFER_OVERLAP("wcsncat_s");
        return EOVERLAP_AND_RESET;
    }
    if (srcLen + destLen >= destMax || strDest == strSrc) {
        strDest[0] = L'\0';
        if (destLen == destMax) {
            SECUREC_ERROR_INVALID_PARAMTER("wcsncat_s");
            return EINVAL_AND_RESET;
        }
        SECUREC_ERROR_INVALID_RANGE("wcsncat_s");
        return ERANGE_AND_RESET;
    }
    SecDoMemcpy(strDest + destLen, strSrc, srcLen * sizeof(wchar_t)); /* no terminator */
    *(strDest + destLen + srcLen) = L'\0';
    return EOK;
}

/*
 * The wcsncat_s function appends not more than n successive wide characters
 * (not including the terminating null wide character)
 * from the array pointed to by strSrc to the end of the wide string pointed to by strDest.
 *
 * The wcsncat_s function tries to append the first D characters of strSrc to
 * the end of strDest, where D is the lesser of count and the length of strSrc.
 * If appending those D characters will fit within strDest (whose size is
 * given as destMax) and still leave room for a null terminator, then those
 * characters are appended, starting at the original terminating null of
 * strDest, and a new terminating null is appended; otherwise, strDest[0] is
 * set to the null character.
 *
 * strDest               Null-terminated destination string.
 * destMax               Size of the destination buffer.
 * strSrc                Null-terminated source string.
 * count                 Number of characters to append, or truncate.
 *
 * strDest is updated
 *
 * EOK                   Success
 * EINVAL                strDest is NULL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN
 * EINVAL_AND_RESET      (strDest unterminated and all other parameters are valid) or
 *                       (strDest != NULL and strSrc is NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_STRING_MAX_LEN)
 * ERANGE                destMax > SECUREC_WCHAR_STRING_MAX_LEN or destMax is 0
 * ERANGE_AND_RESET      strDest does not have enough space and all other parameters are valid and do not overlap
 * EOVERLAP_AND_RESET    dest buffer and source buffer are overlapped and all parameters are valid
 *
 * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid
 */
errno_t wcsncat_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count)
{
    if (destMax == 0 || destMax > SECUREC_WCHAR_STRING_MAX_LEN) {
        SECUREC_ERROR_INVALID_RANGE("wcsncat_s");
        return ERANGE;
    }
    if (strDest == NULL || strSrc == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("wcsncat_s");
        if (strDest != NULL) {
            strDest[0] = L'\0';
            return EINVAL_AND_RESET;
        }
        return EINVAL;
    }
    if (count > SECUREC_WCHAR_STRING_MAX_LEN) {
#ifdef SECUREC_COMPATIBLE_WIN_FORMAT
        if (count == ((size_t)-1)) {
            /* Windows internal functions may pass in -1 when calling this function */
            return SecDoWcsncat(strDest, destMax, strSrc, destMax);
        }
#endif
        strDest[0] = L'\0';
        SECUREC_ERROR_INVALID_RANGE("wcsncat_s");
        return ERANGE_AND_RESET;
    }
    return SecDoWcsncat(strDest, destMax, strSrc, count);
}
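A short sketch of the count semantics (D = min(count, wcslen(strSrc))); NCatExample is illustrative only:

#include <wchar.h>
#include "securec.h"

static void NCatExample(void)
{
    wchar_t buf[8] = L"ms";
    /* append at most 4 characters of the source; a terminator is added afterwards */
    if (wcsncat_s(buf, sizeof(buf) / sizeof(buf[0]), L"serving", 4) == EOK) {
        /* buf holds L"msserv" */
    }
}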
 */

#define SECUREC_INLINE_DO_MEMCPY 1
#include "securecutil.h"

static errno_t SecDoWcsncpy(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count)
{
    size_t srcStrLen;

    if (count < destMax) {
        SECUREC_CALC_WSTR_LEN(strSrc, count, &srcStrLen);
    } else {
        SECUREC_CALC_WSTR_LEN(strSrc, destMax, &srcStrLen);
    }
    if (srcStrLen == destMax) {
        strDest[0] = '\0';
        SECUREC_ERROR_INVALID_RANGE("wcsncpy_s");
        return ERANGE_AND_RESET;
    }
    if (strDest == strSrc) {
        return EOK;
    }
    if (SECUREC_STRING_NO_OVERLAP(strDest, strSrc, srcStrLen)) {
        /* performance optimization; srcStrLen does not include the '\0' */
        SecDoMemcpy(strDest, strSrc, srcStrLen * sizeof(wchar_t));
        *(strDest + srcStrLen) = L'\0';
        return EOK;
    } else {
        strDest[0] = L'\0';
        SECUREC_ERROR_BUFFER_OVERLAP("wcsncpy_s");
        return EOVERLAP_AND_RESET;
    }
}

/*
 * The wcsncpy_s function copies not more than n successive wide characters
 * (not including the terminating null wide character)
 * from the array pointed to by strSrc to the array pointed to by strDest.
 *
 * strDest               Destination string.
 * destMax               The size of the destination string, in characters.
 * strSrc                Source string.
 * count                 Number of characters to be copied.
 *
 * strDest is updated
 *
 * EOK                   Success
 * EINVAL                strDest is NULL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN
 * EINVAL_AND_RESET      strDest != NULL and strSrc is NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_STRING_MAX_LEN
 * ERANGE                destMax > SECUREC_WCHAR_STRING_MAX_LEN or destMax is 0
 * ERANGE_AND_RESET      count > SECUREC_WCHAR_STRING_MAX_LEN or
 *                       (destMax <= length of strSrc and destMax <= count and strDest != strSrc
 *                       and strDest != NULL and strSrc != NULL and destMax != 0 and
 *                       destMax <= SECUREC_WCHAR_STRING_MAX_LEN and not overlap)
 * EOVERLAP_AND_RESET    dest buffer and source buffer are overlapped and all parameters are valid
 *
 * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid
 */
errno_t wcsncpy_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count)
{
    if (destMax == 0 || destMax > SECUREC_WCHAR_STRING_MAX_LEN) {
        SECUREC_ERROR_INVALID_RANGE("wcsncpy_s");
        return ERANGE;
    }
    if (strDest == NULL || strSrc == NULL) {
        SECUREC_ERROR_INVALID_PARAMTER("wcsncpy_s");
        if (strDest != NULL) {
            strDest[0] = '\0';
            return EINVAL_AND_RESET;
        }
        return EINVAL;
    }
    if (count > SECUREC_WCHAR_STRING_MAX_LEN) {
#ifdef SECUREC_COMPATIBLE_WIN_FORMAT
        if (count == (size_t)(-1)) {
            /* Windows internal functions may pass in -1 when calling this function */
            return SecDoWcsncpy(strDest, destMax, strSrc, destMax - 1);
        }
#endif
        strDest[0] = '\0'; /* clear dest string */
        SECUREC_ERROR_INVALID_RANGE("wcsncpy_s");
        return ERANGE_AND_RESET;
    }
    if (count == 0) {
        strDest[0] = '\0';
        return EOK;
    }
    return SecDoWcsncpy(strDest, destMax, strSrc, count);
}

================================================
FILE: third_party/securec/src/wcstok_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
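A sketch of the truncating-copy behavior; NCpyExample is illustrative only:

#include <wchar.h>
#include "securec.h"

static void NCpyExample(void)
{
    wchar_t dst[8];
    /* copy at most 3 characters; the result is always null-terminated on success */
    if (wcsncpy_s(dst, sizeof(dst) / sizeof(dst[0]), L"serving", 3) == EOK) {
        /* dst holds L"ser" */
    }
}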
 */

#include "securec.h"

/*
 * FindBegin wide-character position function
 */
static wchar_t *SecFindBeginW(wchar_t *strToken, const wchar_t *strDelimit)
{
    /*
     * Find the beginning of the token (skip over leading delimiters). Note that
     * there is no token if this loop sets the string to point to the terminal null.
     */
    wchar_t *token = strToken;
    while (*token != L'\0') {
        const wchar_t *ctl = strDelimit;
        while (*ctl != L'\0' && *ctl != *token) {
            ++ctl;
        }
        if (*ctl == L'\0') {
            break;
        }
        ++token;
    }
    return token;
}

/*
 * Find the rest of the token, wide-character position function
 */
static wchar_t *SecFindRestW(wchar_t *strToken, const wchar_t *strDelimit)
{
    /*
     * Find the end of the token. If it is not the end of the string,
     * put a null there.
     */
    wchar_t *token = strToken;
    while (*token != L'\0') {
        const wchar_t *ctl = strDelimit;
        while (*ctl != L'\0' && *ctl != *token) {
            ++ctl;
        }
        if (*ctl != L'\0') {
            *token++ = L'\0';
            break;
        }
        ++token;
    }
    return token;
}

/*
 * Update token wide-character function
 */
static wchar_t *SecUpdateTokenW(wchar_t *strToken, const wchar_t *strDelimit, wchar_t **context)
{
    /* Point to the updated position */
    wchar_t *token = SecFindRestW(strToken, strDelimit);

    /* Update the context */
    *context = token;

    /* Determine if a token has been found */
    if (token == strToken) {
        return NULL;
    }
    return strToken;
}

/*
 * wcstok_s
 *
 * The wcstok_s function is the wide-character equivalent of the strtok_s function.
 *
 * strToken              String containing token or tokens.
 * strDelimit            Set of delimiter characters.
 * context               Used to store position information between calls to wcstok_s.
 *
 * context is updated
 */
wchar_t *wcstok_s(wchar_t *strToken, const wchar_t *strDelimit, wchar_t **context)
{
    wchar_t *orgToken = strToken;
    /* validation section */
    if (context == NULL || strDelimit == NULL) {
        return NULL;
    }
    if (orgToken == NULL && (*context) == NULL) {
        return NULL;
    }
    /* If strToken == NULL, continue with the previous string */
    if (orgToken == NULL) {
        orgToken = *context;
    }
    orgToken = SecFindBeginW(orgToken, strDelimit);
    return SecUpdateTokenW(orgToken, strDelimit, context);
}
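A sketch of the usual tokenizer loop: the first call passes the string, later calls pass NULL and resume from the saved context. TokenExample is illustrative only:

#include <stdio.h>
#include <wchar.h>
#include "securec.h"

static void TokenExample(void)
{
    wchar_t path[] = L"usr/local/serving"; /* must be writable: delimiters are overwritten with L'\0' */
    wchar_t *context = NULL;
    wchar_t *token = wcstok_s(path, L"/", &context);
    while (token != NULL) {
        wprintf(L"token: %ls\n", token);
        token = wcstok_s(NULL, L"/", &context); /* NULL resumes from the saved context */
    }
}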
================================================
FILE: third_party/securec/src/wmemcpy_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "securecutil.h"

/*
 * The wmemcpy_s function copies n successive wide characters
 * from the object pointed to by src into the object pointed to by dest.
 *
 * dest                  Destination buffer.
 * destMax               Size of the destination buffer.
 * src                   Buffer to copy from.
 * count                 Number of characters to copy.
 *
 * dest buffer is updated.
 *
 * EOK                   Success
 * EINVAL                dest is NULL and destMax != 0 and count <= destMax
 *                       and destMax <= SECUREC_WCHAR_MEM_MAX_LEN
 * EINVAL_AND_RESET      dest != NULL and src is NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_MEM_MAX_LEN and count <= destMax
 * ERANGE                destMax > SECUREC_WCHAR_MEM_MAX_LEN or destMax is 0 or
 *                       (count > destMax and dest is NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_MEM_MAX_LEN)
 * ERANGE_AND_RESET      count > destMax and dest != NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_MEM_MAX_LEN
 * EOVERLAP_AND_RESET    dest buffer and source buffer are overlapped and
 *                       count <= destMax and destMax != 0 and destMax <= SECUREC_WCHAR_MEM_MAX_LEN
 *                       and dest != NULL and src != NULL and dest != src
 *
 * If an error occurred, dest will be filled with 0 when dest and destMax valid.
 * If the source and destination overlap, the behavior of wmemcpy_s is undefined.
 * Use wmemmove_s to handle overlapping regions.
 */
errno_t wmemcpy_s(wchar_t *dest, size_t destMax, const wchar_t *src, size_t count)
{
    if (destMax == 0 || destMax > SECUREC_WCHAR_MEM_MAX_LEN) {
        SECUREC_ERROR_INVALID_PARAMTER("wmemcpy_s");
        return ERANGE;
    }
    if (count > destMax) {
        SECUREC_ERROR_INVALID_PARAMTER("wmemcpy_s");
        if (dest != NULL) {
            (void)memset(dest, 0, destMax * sizeof(wchar_t));
            return ERANGE_AND_RESET;
        }
        return ERANGE;
    }
    return memcpy_s(dest, destMax * sizeof(wchar_t), src, count * sizeof(wchar_t));
}
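As the delegation to memcpy_s shows, destMax and count are wide-character counts that the wrapper scales by sizeof(wchar_t). A short illustrative sketch:

#include <wchar.h>
#include "securec.h"

static void WMemcpyExample(void)
{
    wchar_t src[4] = L"abc";
    wchar_t dst[4];
    /* counts are in wide characters, not bytes */
    if (wmemcpy_s(dst, 4, src, 4) != EOK) {
        /* on failure with a valid dest, the buffer has been zero-filled */
    }
}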
================================================
FILE: third_party/securec/src/wmemmove_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "securecutil.h"

/*
 * The wmemmove_s function copies n successive wide characters from the object pointed
 * to by src into the object pointed to by dest.
 *
 * dest                  Destination buffer.
 * destMax               Size of the destination buffer.
 * src                   Source object.
 * count                 Number of characters to copy.
 *
 * dest is updated.
 *
 * EOK                   Success
 * EINVAL                dest is NULL and destMax != 0 and count <= destMax
 *                       and destMax <= SECUREC_WCHAR_MEM_MAX_LEN
 * EINVAL_AND_RESET      dest != NULL and src is NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_MEM_MAX_LEN and count <= destMax
 * ERANGE                destMax > SECUREC_WCHAR_MEM_MAX_LEN or destMax is 0 or
 *                       (count > destMax and dest is NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_MEM_MAX_LEN)
 * ERANGE_AND_RESET      count > destMax and dest != NULL and destMax != 0
 *                       and destMax <= SECUREC_WCHAR_MEM_MAX_LEN
 *
 * If an error occurred, dest will be filled with 0 when dest and destMax valid.
 * If some regions of the source area and the destination overlap, wmemmove_s
 * ensures that the original source bytes in the overlapping region are copied
 * before being overwritten.
 */
errno_t wmemmove_s(wchar_t *dest, size_t destMax, const wchar_t *src, size_t count)
{
    if (destMax == 0 || destMax > SECUREC_WCHAR_MEM_MAX_LEN) {
        SECUREC_ERROR_INVALID_PARAMTER("wmemmove_s");
        return ERANGE;
    }
    if (count > destMax) {
        SECUREC_ERROR_INVALID_PARAMTER("wmemmove_s");
        if (dest != NULL) {
            (void)memset(dest, 0, destMax * sizeof(wchar_t));
            return ERANGE_AND_RESET;
        }
        return ERANGE;
    }
    return memmove_s(dest, destMax * sizeof(wchar_t), src, count * sizeof(wchar_t));
}

================================================
FILE: third_party/securec/src/wscanf_s.c
================================================
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "securec.h"

/*
 * The wscanf_s function is the wide-character equivalent of the scanf_s function.
 * The wscanf_s function reads data from the standard input stream stdin and
 * writes the data into the location given by argument. Each argument
 * must be a pointer to a variable of a type that corresponds to a type specifier
 * in format. If copying occurs between strings that overlap, the behavior is
 * undefined.
 *
 * format                Format control string.
 * ...                   Optional arguments.
 *
 * ...                   the converted value stored in user assigned address
 *
 * Returns the number of fields successfully converted and assigned;
 * the return value does not include fields that were read but not assigned.
 * A return value of 0 indicates that no fields were assigned.
 * return -1 if an error occurs.
 */
int wscanf_s(const wchar_t *format, ...)
{
    int ret; /* If initialization causes e838 */
    va_list argList;

    va_start(argList, format);
    ret = vwscanf_s(format, argList);
    va_end(argList);
    (void)argList; /* To clear e438 (last value assigned not used); the compiler will optimize this away */

    return ret;
}
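Unlike wmemcpy_s, wmemmove_s tolerates overlapping regions, which makes it suitable for in-place shifts. A closing illustrative sketch; WMemmoveExample is not part of the library:

#include <wchar.h>
#include "securec.h"

static void WMemmoveExample(void)
{
    wchar_t seq[8] = L"ABCDEFG";
    /* overlapping shift-left by two; the original source characters are
     * copied safely, which wmemcpy_s explicitly does not guarantee */
    if (wmemmove_s(seq, 8, seq + 2, 6) == EOK) {
        /* seq now holds L"CDEFG" (the count of 6 includes the terminator) */
    }
}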